You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2017/08/30 21:30:58 UTC

[1/2] beam git commit: [BEAM-1347] Create value state, combining state, and bag state views over the BagUserState.

Repository: beam
Updated Branches:
  refs/heads/master f6c840533 -> 585440d22


[BEAM-1347] Create value state, combining state, and bag state views over the BagUserState.

Also bind the state persistence to the end of finishBundle.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/e0f628cc
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/e0f628cc
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/e0f628cc

Branch: refs/heads/master
Commit: e0f628cc7fbf6cbfb46825d6ee7bbc29e0bd66f5
Parents: f6c8405
Author: Luke Cwik <lc...@google.com>
Authored: Tue Aug 29 10:45:04 2017 -0700
Committer: Luke Cwik <lc...@google.com>
Committed: Wed Aug 30 14:30:27 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/fn/harness/FnApiDoFnRunner.java | 380 ++++++++++++++++++-
 .../beam/fn/harness/FnApiDoFnRunnerTest.java    | 229 +++++++++++
 .../fn/harness/state/FakeBeamFnStateClient.java |   2 +-
 3 files changed, 605 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/e0f628cc/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index d325bb2..c361647 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -18,45 +18,77 @@
 package org.apache.beam.fn.harness;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.service.AutoService;
+import com.google.common.base.Suppliers;
 import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableMultimap;
 import com.google.common.collect.Multimap;
 import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Objects;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import java.util.function.Supplier;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.fn.ThrowingConsumer;
 import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.fn.harness.state.BagUserState;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
+import org.apache.beam.fn.v1.BeamFnApi.StateKey;
+import org.apache.beam.fn.v1.BeamFnApi.StateRequest;
+import org.apache.beam.fn.v1.BeamFnApi.StateRequest.Builder;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.dataflow.util.DoFnInfo;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.MapState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
+import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.State;
+import org.apache.beam.sdk.state.StateBinder;
+import org.apache.beam.sdk.state.StateContext;
+import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFn.OnTimerContext;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
@@ -141,7 +173,13 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
       @SuppressWarnings({"unchecked", "rawtypes"})
       DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>(
           pipelineOptions,
+          beamFnStateClient,
+          pTransformId,
+          processBundleInstructionId,
           doFnInfo.getDoFn(),
+          WindowedValue.getFullCoder(
+              doFnInfo.getInputCoder(),
+              doFnInfo.getWindowingStrategy().getWindowFn().windowCoder()),
           (Collection<ThrowingConsumer<WindowedValue<OutputT>>>) (Collection)
               tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())),
           tagToOutputMap,
@@ -162,42 +200,68 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
   //////////////////////////////////////////////////////////////////////////////////////////////////
 
   private final PipelineOptions pipelineOptions;
+  private final BeamFnStateClient beamFnStateClient;
+  private final String ptransformId;
+  private final Supplier<String> processBundleInstructionId;
   private final DoFn<InputT, OutputT> doFn;
+  private final WindowedValueCoder<InputT> inputCoder;
   private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers;
   private final Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap;
+  private final WindowingStrategy windowingStrategy;
+  private final DoFnSignature doFnSignature;
   private final DoFnInvoker<InputT, OutputT> doFnInvoker;
+  private final StateBinder stateBinder;
   private final StartBundleContext startBundleContext;
   private final ProcessBundleContext processBundleContext;
   private final FinishBundleContext finishBundleContext;
-  private final WindowingStrategy windowingStrategy;
-  private final DoFnSignature doFnSignature;
+  private final Collection<ThrowingRunnable> stateFinalizers;
 
   /**
-   * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}.
+   * The lifetime of this member is only valid during {@link #processElement}
+   * and is null otherwise.
    */
   private WindowedValue<InputT> currentElement;
 
   /**
-   * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}.
+   * The lifetime of this member is only valid during {@link #processElement}
+   * and is null otherwise.
    */
   private BoundedWindow currentWindow;
 
+  /**
+   * This member should only be accessed indirectly by calling
+   * {@link #createOrUseCachedBagUserStateKey} and is only valid during {@link #processElement}
+   * and is null otherwise.
+   */
+  private StateKey.BagUserState cachedPartialBagUserStateKey;
+
+
   FnApiDoFnRunner(
       PipelineOptions pipelineOptions,
+      BeamFnStateClient beamFnStateClient,
+      String ptransformId,
+      Supplier<String> processBundleInstructionId,
       DoFn<InputT, OutputT> doFn,
+      WindowedValueCoder<InputT> inputCoder,
       Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers,
       Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap,
       WindowingStrategy windowingStrategy) {
     this.pipelineOptions = pipelineOptions;
+    this.beamFnStateClient = beamFnStateClient;
+    this.ptransformId = ptransformId;
+    this.processBundleInstructionId = processBundleInstructionId;
     this.doFn = doFn;
+    this.inputCoder = inputCoder;
     this.mainOutputConsumers = mainOutputConsumers;
     this.outputMap = outputMap;
     this.windowingStrategy = windowingStrategy;
     this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
     this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
+    this.stateBinder = new BeamFnStateBinder();
     this.startBundleContext = new StartBundleContext();
     this.processBundleContext = new ProcessBundleContext();
     this.finishBundleContext = new FinishBundleContext();
+    this.stateFinalizers = new ArrayList<>();
   }
 
   @Override
@@ -218,6 +282,7 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
     } finally {
       currentElement = null;
       currentWindow = null;
+      cachedPartialBagUserStateKey = null;
     }
   }
 
@@ -233,6 +298,18 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
   @Override
   public void finishBundle() {
     doFnInvoker.invokeFinishBundle(finishBundleContext);
+
+    // Persist all dirty state cells
+    try {
+      for (ThrowingRunnable runnable : stateFinalizers) {
+        runnable.run();
+      }
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new IllegalStateException(e);
+    } catch (Exception e) {
+      throw new IllegalStateException(e);
+    }
   }
 
   /**
@@ -367,7 +444,15 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
 
     @Override
     public State state(String stateId) {
-      throw new UnsupportedOperationException("TODO: Add support for state");
+      StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId);
+      checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
+      StateSpec<?> spec;
+      try {
+        spec = (StateSpec<?>) stateDeclaration.field().get(doFn);
+      } catch (IllegalAccessException e) {
+        throw new RuntimeException(e);
+      }
+      return spec.bind(stateId, stateBinder);
     }
 
     @Override
@@ -545,4 +630,289 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
           WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
     }
   }
+
+  /**
+   * A {@link StateBinder} that uses the Beam Fn State API to read and write user state.
+   *
+   * <p>TODO: Add support for {@link #bindMap} and {@link #bindSet}. Note that
+   * {@link #bindWatermark} should never be implemented.
+   */
+  private class BeamFnStateBinder implements StateBinder {
+    private final Map<StateKey.BagUserState, Object> stateObjectCache = new HashMap<>();
+
+    /**
+     * Exposes a single-element view over a {@link BagUserState}. The view is cached per
+     * state key so repeated binds for the same key, window and id share one instance.
+     *
+     * <p>Reading returns {@code null} when the underlying bag holds no element.
+     */
+    @Override
+    public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
+      return (ValueState<T>) stateObjectCache.computeIfAbsent(
+          createOrUseCachedBagUserStateKey(id),
+          // The mapping function only runs on a cache miss.
+          unusedKey -> new ValueState<T>() {
+            private final BagUserState<T> bag = createBagUserState(id, coder);
+
+            @Override
+            public T read() {
+              // The value, when present, is the first element held by the bag.
+              Iterator<T> contents = bag.get().iterator();
+              return contents.hasNext() ? contents.next() : null;
+            }
+
+            @Override
+            public void write(T input) {
+              // Overwrite semantics: drop whatever was there before storing the new value.
+              bag.clear();
+              bag.append(input);
+            }
+
+            @Override
+            public void clear() {
+              bag.clear();
+            }
+
+            @Override
+            public ValueState<T> readLater() {
+              // TODO: Support prefetching.
+              return this;
+            }
+          });
+    }
+
+    /**
+     * Exposes a {@link BagState} view that forwards directly to a {@link BagUserState}. The
+     * view is cached per state key for reuse across binds.
+     */
+    @Override
+    public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
+      return (BagState<T>) stateObjectCache.computeIfAbsent(
+          createOrUseCachedBagUserStateKey(id),
+          unusedKey -> new BagState<T>() {
+            private final BagUserState<T> bag = createBagUserState(id, elemCoder);
+
+            @Override
+            public Iterable<T> read() {
+              return bag.get();
+            }
+
+            @Override
+            public ReadableState<Boolean> isEmpty() {
+              // Evaluated eagerly against the bag contents at call time.
+              return ReadableStates.immediate(!bag.get().iterator().hasNext());
+            }
+
+            @Override
+            public void add(T value) {
+              bag.append(value);
+            }
+
+            @Override
+            public void clear() {
+              bag.clear();
+            }
+
+            @Override
+            public BagState<T> readLater() {
+              // TODO: Support prefetching.
+              return this;
+            }
+          });
+    }
+
+    @Override
+    public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
+      throw new UnsupportedOperationException("TODO: Add support for a set state to the Fn API.");
+    }
+
+    @Override
+    public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(String id,
+        StateSpec<MapState<KeyT, ValueT>> spec, Coder<KeyT> mapKeyCoder,
+        Coder<ValueT> mapValueCoder) {
+      throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
+    }
+
+    /**
+     * Exposes a {@link CombiningState} view whose accumulator is read from the first element
+     * of a {@link BagUserState}. The view is cached per state key for reuse across binds.
+     *
+     * <p>{@code read} yields the combine function's default value while the bag is empty.
+     */
+    @Override
+    public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCombining(
+        String id,
+        StateSpec<CombiningState<InputT, AccumT, OutputT>> spec, Coder<AccumT> accumCoder,
+        CombineFn<InputT, AccumT, OutputT> combineFn) {
+      return (CombiningState<InputT, AccumT, OutputT>) stateObjectCache.computeIfAbsent(
+          createOrUseCachedBagUserStateKey(id),
+          // TODO: Support squashing accumulators depending on whether we know of all
+          // remote accumulators and local accumulators or just local accumulators.
+          unusedKey -> new CombiningState<InputT, AccumT, OutputT>() {
+            private final BagUserState<AccumT> accumBag = createBagUserState(id, accumCoder);
+
+            @Override
+            public OutputT read() {
+              Iterator<AccumT> stored = accumBag.get().iterator();
+              return stored.hasNext()
+                  ? combineFn.extractOutput(stored.next())
+                  : combineFn.defaultValue();
+            }
+
+            @Override
+            public AccumT getAccum() {
+              // Fall back to a fresh accumulator when nothing has been stored yet.
+              Iterator<AccumT> stored = accumBag.get().iterator();
+              return stored.hasNext() ? stored.next() : combineFn.createAccumulator();
+            }
+
+            @Override
+            public void add(InputT value) {
+              // Replace the stored accumulator with one that includes the new input.
+              AccumT updated = combineFn.addInput(getAccum(), value);
+              accumBag.clear();
+              accumBag.append(updated);
+            }
+
+            @Override
+            public void addAccum(AccumT accum) {
+              Iterator<AccumT> stored = accumBag.get().iterator();
+              AccumT toStore = accum;
+              if (stored.hasNext()) {
+                // A prior accumulator exists: merge it in, then drop it from the bag.
+                toStore = combineFn.mergeAccumulators(ImmutableList.of(stored.next(), accum));
+                accumBag.clear();
+              }
+              accumBag.append(toStore);
+            }
+
+            @Override
+            public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+              // Pure delegation; the stored accumulator is not touched.
+              return combineFn.mergeAccumulators(accumulators);
+            }
+
+            @Override
+            public ReadableState<Boolean> isEmpty() {
+              // Evaluated eagerly against the bag contents at call time.
+              return ReadableStates.immediate(!accumBag.get().iterator().hasNext());
+            }
+
+            @Override
+            public void clear() {
+              accumBag.clear();
+            }
+
+            @Override
+            public CombiningState<InputT, AccumT, OutputT> readLater() {
+              // TODO: Support prefetching.
+              return this;
+            }
+          });
+    }
+
+    /**
+     * Binds a context-aware combine function by adapting it to a plain {@link CombineFn} and
+     * delegating to {@link #bindCombining}, which caches the resulting state by state key.
+     *
+     * <p>The delegate is called directly rather than from within a second {@code computeIfAbsent}
+     * on the same key, because a mapping function must not modify the map being computed.
+     */
+    @Override
+    public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT>
+    bindCombiningWithContext(
+        String id,
+        StateSpec<CombiningState<InputT, AccumT, OutputT>> spec,
+        Coder<AccumT> accumCoder,
+        CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
+      return bindCombining(id, spec, accumCoder, CombineFnUtil.bindContext(combineFn,
+          new StateContext<BoundedWindow>() {
+            @Override
+            public PipelineOptions getPipelineOptions() {
+              return pipelineOptions;
+            }
+
+            @Override
+            public <T> T sideInput(PCollectionView<T> view) {
+              return processBundleContext.sideInput(view);
+            }
+
+            @Override
+            public BoundedWindow window() {
+              return currentWindow;
+            }
+          }));
+    }
+
+    /**
+     * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing
+     * and is waiting on resolution of BEAM-2535.
+     */
+    @Override
+    @Deprecated
+    public WatermarkHoldState bindWatermark(String id, StateSpec<WatermarkHoldState> spec,
+        TimestampCombiner timestampCombiner) {
+      throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
+    }
+
+    /** Creates a {@link BagUserState} for {@code id} and registers its close as a finalizer. */
+    private <T> BagUserState<T> createBagUserState(String id, Coder<T> coder) {
+      BagUserState<T> rval = new BagUserState<T>(
+          beamFnStateClient, id, coder,
+          new Supplier<StateRequest.Builder>() {
+            /** Memoizes the partial state key for the lifetime of the {@link BagUserState}. */
+            private final Supplier<StateKey.BagUserState> memoizingSupplier =
+                Suppliers.memoize(() -> createOrUseCachedBagUserStateKey(id))::get;
+
+            @Override
+            public Builder get() {
+              return StateRequest.newBuilder()
+                  .setInstructionReference(processBundleInstructionId.get())
+                  .setStateKey(StateKey.newBuilder()
+                      .setBagUserState(memoizingSupplier.get()));
+            }
+          });
+      // finishBundle runs every registered finalizer, invoking asyncClose on this state.
+      stateFinalizers.add(rval::asyncClose);
+      return rval;
+    }
+  }
+
+  /**
+   * Memoizes a partially built {@link StateKey} saving on the encoding cost of the key and
+   * window across multiple state cells for the lifetime of {@link #processElement}.
+   *
+   * <p>This should only be called during {@link #processElement}; otherwise there is no
+   * current element to derive the key from and an {@link IllegalStateException} is thrown.
+   */
+  private <K> StateKey.BagUserState createOrUseCachedBagUserStateKey(String id) {
+    if (cachedPartialBagUserStateKey == null) {
+      checkState(currentElement != null,
+          "Accessing state outside of processElement. There is no current element.");
+      checkState(currentElement.getValue() instanceof KV,
+          "Accessing state in unkeyed context. Current element is not a KV: %s.",
+          currentElement);
+      // Validate the same coder that is cast below to obtain the key coder.
+      checkState(inputCoder.getValueCoder() instanceof KvCoder,
+          "Accessing state in unkeyed context. No keyed coder found.");
+
+      Coder<K> keyCoder = ((KvCoder<K, ?>) inputCoder.getValueCoder()).getKeyCoder();
+
+      ByteString.Output encodedKeyOut = ByteString.newOutput();
+      ByteString.Output encodedWindowOut = ByteString.newOutput();
+      try {
+        keyCoder.encode(((KV<K, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
+        windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, encodedWindowOut);
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+
+      cachedPartialBagUserStateKey = StateKey.BagUserState.newBuilder()
+          .setPtransformId(ptransformId)
+          .setKey(encodedKeyOut.toByteString())
+          .setWindow(encodedWindowOut.toByteString()).buildPartial();
+    }
+    // Only the user state id differs between cells of the same element.
+    return cachedPartialBagUserStateKey.toBuilder().setUserStateId(id).build();
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/e0f628cc/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index ebec608..4aa8080 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -22,6 +22,8 @@ import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWin
 import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.empty;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 
@@ -32,22 +34,36 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Multimap;
 import com.google.protobuf.ByteString;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.ServiceLoader;
 import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
 import org.apache.beam.fn.harness.fn.ThrowingConsumer;
 import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
+import org.apache.beam.fn.v1.BeamFnApi.StateKey;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.dataflow.util.DoFnInfo;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import org.apache.beam.sdk.transforms.CombineWithContext.Context;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.hamcrest.collection.IsMapContaining;
@@ -58,6 +74,9 @@ import org.junit.runners.JUnit4;
 /** Tests for {@link FnApiDoFnRunner}. */
 @RunWith(JUnit4.class)
 public class FnApiDoFnRunnerTest {
+
+  public static final String TEST_PTRANSFORM_ID = "pTransformId";
+
   private static class TestDoFn extends DoFn<String, String> {
     private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
     private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
@@ -164,6 +183,216 @@ public class FnApiDoFnRunnerTest {
     mainOutputValues.clear();
   }
 
+  /** Test combine function whose accumulator and output are the concatenated inputs. */
+  private static class ConcatCombineFn extends CombineFn<String, String, String> {
+    @Override
+    public String createAccumulator() {
+      // Concatenation starts from the empty string.
+      return "";
+    }
+
+    @Override
+    public String addInput(String accumulator, String input) {
+      return accumulator.concat(input);
+    }
+
+    @Override
+    public String mergeAccumulators(Iterable<String> accumulators) {
+      // Joins in iteration order with no separator.
+      return String.join("", accumulators);
+    }
+
+    @Override
+    public String extractOutput(String accumulator) {
+      // The accumulator already is the output.
+      return accumulator;
+    }
+  }
+
+  /** String concatenation as a {@link CombineFnWithContext}; the context is never used. */
+  private static class ConcatCombineFnWithContext
+      extends CombineFnWithContext<String, String, String> {
+    @Override
+    public String createAccumulator(Context c) {
+      // Concatenation starts from the empty string.
+      return "";
+    }
+
+    @Override
+    public String addInput(String accumulator, String input, Context c) {
+      return accumulator.concat(input);
+    }
+
+    @Override
+    public String mergeAccumulators(Iterable<String> accumulators, Context c) {
+      // Joins in iteration order with no separator.
+      return String.join("", accumulators);
+    }
+
+    @Override
+    public String extractOutput(String accumulator, Context c) {
+      // The accumulator already is the output.
+      return accumulator;
+    }
+  }
+
+  private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> {
+    private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
+    private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
+
+    @StateId("value")
+    private final StateSpec<ValueState<String>> valueStateSpec =
+        StateSpecs.value(StringUtf8Coder.of());
+    @StateId("bag")
+    private final StateSpec<BagState<String>> bagStateSpec =
+        StateSpecs.bag(StringUtf8Coder.of());
+    @StateId("combine")
+    private final StateSpec<CombiningState<String, String, String>> combiningStateSpec =
+        StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn());
+    @StateId("combineWithContext")
+    private final StateSpec<CombiningState<String, String, String>> combiningWithContextStateSpec =
+        StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFnWithContext());
+
+    @ProcessElement
+    public void processElement(ProcessContext context,
+        @StateId("value") ValueState<String> valueState,
+        @StateId("bag") BagState<String> bagState,
+        @StateId("combine") CombiningState<String, String, String> combiningState,
+        @StateId("combineWithContext")
+            CombiningState<String, String, String> combiningWithContextState) {
+      context.output("value:" + valueState.read());
+      valueState.write(context.element().getValue());
+
+      context.output("bag:" + Iterables.toString(bagState.read()));
+      bagState.add(context.element().getValue());
+
+      context.output("combine:" + combiningState.read());
+      combiningState.add(context.element().getValue());
+
+      context.output("combineWithContext:" + combiningWithContextState.read());
+      combiningWithContextState.add(context.element().getValue());
+    }
+  }
+
+  @Test
+  public void testUsingUserState() throws Exception {
+    String mainOutputId = "101";
+
+    DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
+        new TestStatefulDoFn(),
+        WindowingStrategy.globalDefault(),
+        ImmutableList.of(),
+        KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()),
+        Long.parseLong(mainOutputId),
+        ImmutableMap.of(Long.parseLong(mainOutputId), new TupleTag<String>("mainOutput")));
+    RunnerApi.FunctionSpec functionSpec =
+        RunnerApi.FunctionSpec.newBuilder()
+            .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
+            .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
+            .build();
+    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+        .setSpec(functionSpec)
+        .putInputs("input", "inputTarget")
+        .putOutputs(mainOutputId, "mainOutputTarget")
+        .build();
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of(
+        key("value", "X"), encode("X0"),
+        key("bag", "X"), encode("X0"),
+        key("combine", "X"), encode("X0"),
+        key("combineWithContext", "X"), encode("X0")
+    ));
+
+    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+    consumers.put("mainOutputTarget",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        null /* beamFnDataClient */,
+        fakeClient,
+        TEST_PTRANSFORM_ID,
+        pTransform,
+        Suppliers.ofInstance("57L")::get,
+        ImmutableMap.of(),
+        ImmutableMap.of(),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    Iterables.getOnlyElement(startFunctions).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder("inputTarget", "mainOutputTarget"));
+
+    // Ensure that bag user state that is initially empty or populated works.
+    // Ensure that the key order does not matter when we traverse over KV pairs.
+    ThrowingConsumer<WindowedValue<?>> mainInput =
+        Iterables.getOnlyElement(consumers.get("inputTarget"));
+    mainInput.accept(valueInGlobalWindow(KV.of("X", "X1")));
+    mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1")));
+    mainInput.accept(valueInGlobalWindow(KV.of("X", "X2")));
+    mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y2")));
+    assertThat(mainOutputValues, contains(
+        valueInGlobalWindow("value:X0"),
+        valueInGlobalWindow("bag:[X0]"),
+        valueInGlobalWindow("combine:X0"),
+        valueInGlobalWindow("combineWithContext:X0"),
+        valueInGlobalWindow("value:null"),
+        valueInGlobalWindow("bag:[]"),
+        valueInGlobalWindow("combine:"),
+        valueInGlobalWindow("combineWithContext:"),
+        valueInGlobalWindow("value:X1"),
+        valueInGlobalWindow("bag:[X0, X1]"),
+        valueInGlobalWindow("combine:X0X1"),
+        valueInGlobalWindow("combineWithContext:X0X1"),
+        valueInGlobalWindow("value:Y1"),
+        valueInGlobalWindow("bag:[Y1]"),
+        valueInGlobalWindow("combine:Y1"),
+        valueInGlobalWindow("combineWithContext:Y1")));
+    mainOutputValues.clear();
+
+    Iterables.getOnlyElement(finishFunctions).run();
+    assertThat(mainOutputValues, empty());
+
+    assertEquals(
+        ImmutableMap.<StateKey, ByteString>builder()
+            .put(key("value", "X"), encode("X2"))
+            .put(key("bag", "X"), encode("X0", "X1", "X2"))
+            .put(key("combine", "X"), encode("X0X1X2"))
+            .put(key("combineWithContext", "X"), encode("X0X1X2"))
+            .put(key("value", "Y"), encode("Y2"))
+            .put(key("bag", "Y"), encode("Y1", "Y2"))
+            .put(key("combine", "Y"), encode("Y1Y2"))
+            .put(key("combineWithContext", "Y"), encode("Y1Y2"))
+            .build(),
+        fakeClient.getData());
+    mainOutputValues.clear();
+  }
+
+  /** Produces a {@link StateKey} for the test PTransform id in the Global Window. */
+  private StateKey key(String userStateId, String key) throws IOException {
+    return StateKey.newBuilder().setBagUserState(
+        StateKey.BagUserState.newBuilder()
+            .setPtransformId(TEST_PTRANSFORM_ID)
+            .setUserStateId(userStateId)
+            .setKey(encode(key))
+            .setWindow(ByteString.copyFrom(
+                CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE))))
+        .build();
+  }
+
+  private ByteString encode(String ... values) throws IOException {
+    ByteString.Output out = ByteString.newOutput();
+    for (String value : values) {
+      StringUtf8Coder.of().encode(value, out);
+    }
+    return out.toByteString();
+  }
+
   @Test
   public void testRegistration() {
     for (Registrar registrar :

http://git-wip-us.apache.org/repos/asf/beam/blob/e0f628cc/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java
index d260207..60080e1 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java
@@ -69,7 +69,7 @@ public class FakeBeamFnStateClient implements BeamFnStateClient {
     switch (request.getRequestCase()) {
       case GET:
         // Chunk gets into 5 byte return blocks
-        ByteString byteString = data.get(request.getStateKey());
+        ByteString byteString = data.getOrDefault(request.getStateKey(), ByteString.EMPTY);
         int block = 0;
         if (request.getGet().getContinuationToken().size() > 0) {
           block = Integer.parseInt(request.getGet().getContinuationToken().toStringUtf8());


[2/2] beam git commit: [BEAM-1347] Create value state, combining state, and bag state views over the BagUserState.

Posted by lc...@apache.org.
[BEAM-1347] Create value state, combining state, and bag state views over the BagUserState.

This closes #3783


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/585440d2
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/585440d2
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/585440d2

Branch: refs/heads/master
Commit: 585440d228db2eae841bc92fa0babd9e131ef839
Parents: f6c8405 e0f628c
Author: Luke Cwik <lc...@google.com>
Authored: Wed Aug 30 14:30:51 2017 -0700
Committer: Luke Cwik <lc...@google.com>
Committed: Wed Aug 30 14:30:51 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/fn/harness/FnApiDoFnRunner.java | 380 ++++++++++++++++++-
 .../beam/fn/harness/FnApiDoFnRunnerTest.java    | 229 +++++++++++
 .../fn/harness/state/FakeBeamFnStateClient.java |   2 +-
 3 files changed, 605 insertions(+), 6 deletions(-)
----------------------------------------------------------------------