You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bc...@apache.org on 2017/04/24 18:57:15 UTC

[1/2] beam git commit: [BEAM-1148] Port PAssert away from aggregators

Repository: beam
Updated Branches:
  refs/heads/master 37e532188 -> 7a2fe68fd


[BEAM-1148] Port PAssert away from aggregators

Separates the evaluation of an assertion, done by a transform that
outputs `SuccessOrFailure`, from the reporting of failures. Reporting
happens in a separate composite transform, which makes it possible to
override its implementation.

Introduces a default reporting implementation that uses Metrics to
count the number of successful assertions as well as the number of
failed assertions.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/e8f0922f
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/e8f0922f
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/e8f0922f

Branch: refs/heads/master
Commit: e8f0922f6f15dd3ab96f54ba4a5c1083269d70bd
Parents: 37e5321
Author: Pablo <pa...@google.com>
Authored: Wed Mar 29 14:49:53 2017 -0700
Committer: bchambers <bc...@google.com>
Committed: Mon Apr 24 11:56:58 2017 -0700

----------------------------------------------------------------------
 .../beam/runners/spark/TestSparkRunner.java     |  35 ++--
 .../ResumeFromCheckpointStreamingTest.java      |  72 ++++----
 .../org/apache/beam/sdk/testing/PAssert.java    | 164 ++++++++++++-------
 .../beam/sdk/testing/SuccessOrFailure.java      |  82 ++++++++++
 .../apache/beam/sdk/testing/PAssertTest.java    |  55 +++++++
 5 files changed, 307 insertions(+), 101 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index 61fcaa9..10e98b8 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -40,6 +40,9 @@ import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource;
+import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.MetricResult;
+import org.apache.beam.sdk.metrics.MetricsFilter;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PTransformOverride;
@@ -136,11 +139,15 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
             isOneOf(PipelineResult.State.STOPPED, PipelineResult.State.DONE));
 
         // validate assertion succeeded (at least once).
-        int successAssertions = 0;
-        try {
-          successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
-        } catch (NullPointerException e) {
-          // No assertions registered will cause an NPE here.
+        long successAssertions = 0;
+        Iterable<MetricResult<Long>> counterResults = result.metrics().queryMetrics(
+            MetricsFilter.builder()
+                .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.SUCCESS_COUNTER))
+                .build()).counters();
+        for (MetricResult<Long> counter : counterResults) {
+          if (counter.attempted().longValue() > 0) {
+            successAssertions++;
+          }
         }
         Integer expectedAssertions = testSparkPipelineOptions.getExpectedAssertions() != null
             ? testSparkPipelineOptions.getExpectedAssertions() : expectedNumberOfAssertions;
@@ -149,18 +156,22 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
                 "Expected %d successful assertions, but found %d.",
                 expectedAssertions, successAssertions),
             successAssertions,
-            is(expectedAssertions));
+            is(expectedAssertions.longValue()));
         // validate assertion didn't fail.
-        int failedAssertions = 0;
-        try {
-          failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class);
-        } catch (NullPointerException e) {
-          // No assertions registered will cause an NPE here.
+        long failedAssertions = 0;
+        Iterable<MetricResult<Long>> failCounterResults = result.metrics().queryMetrics(
+            MetricsFilter.builder()
+                .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.FAILURE_COUNTER))
+                .build()).counters();
+        for (MetricResult<Long> counter : failCounterResults) {
+          if (counter.attempted().longValue() > 0) {
+            failedAssertions++;
+          }
         }
         assertThat(
             String.format("Found %d failed assertions.", failedAssertions),
             failedAssertions,
-            is(0));
+            is(0L));
 
         LOG.info(
             String.format(

http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
index 6cbf83a..1aa76a3 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
@@ -19,7 +19,6 @@ package org.apache.beam.runners.spark.translation.streaming;
 
 import static org.apache.beam.sdk.metrics.MetricMatchers.attemptedMetricsResult;
 import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertThat;
@@ -51,10 +50,10 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.kafka.KafkaIO;
 import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.MetricResult;
 import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.metrics.MetricsFilter;
 import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
@@ -62,7 +61,6 @@ import org.apache.beam.sdk.transforms.Keys;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.WithKeys;
@@ -94,8 +92,8 @@ import org.junit.experimental.categories.Category;
  * <p>Runs the pipeline reading from a Kafka backlog with a WM function that will move to infinity
  * on a EOF signal.
  * After resuming from checkpoint, a single output (guaranteed by the WM) is asserted, along with
- * {@link Aggregator}s and {@link Metrics} values that are expected to resume from previous count
- * and a side-input that is expected to recover as well.
+ * {@link Metrics} values that are expected to resume from previous count and a side-input that is
+ * expected to recover as well.
  */
 public class ResumeFromCheckpointStreamingTest {
   private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER =
@@ -161,16 +159,13 @@ public class ResumeFromCheckpointStreamingTest {
 
     // first run should expect EOT matching the last injected element.
     SparkPipelineResult res = run(pipelineRule, Optional.of(new Instant(400)), 0);
-    // assertions 1:
-    long processedMessages1 = res.getAggregatorValue("processedMessages", Long.class);
-    assertThat(
-        String.format(
-            "Expected %d processed messages count but found %d", 4, processedMessages1),
-        processedMessages1,
-        equalTo(4L));
+
     assertThat(res.metrics().queryMetrics(metricsFilter).counters(),
         hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(),
             "allMessages", "EOFShallNotPassFn", 4L)));
+    assertThat(res.metrics().queryMetrics(metricsFilter).counters(),
+        hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(),
+            "processedMessages", "EOFShallNotPassFn", 4L)));
 
     //--- between executions:
 
@@ -186,27 +181,42 @@ public class ResumeFromCheckpointStreamingTest {
     // recovery should resume from last read offset, and read the second batch of input.
     res = runAgain(pipelineRule, 1);
     // assertions 2:
-    long processedMessages2 = res.getAggregatorValue("processedMessages", Long.class);
-    assertThat(
-        String.format("Expected %d processed messages count but found %d", 5, processedMessages2),
-        processedMessages2,
-        equalTo(5L));
+    assertThat(res.metrics().queryMetrics(metricsFilter).counters(),
+        hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(),
+            "processedMessages", "EOFShallNotPassFn", 5L)));
     assertThat(res.metrics().queryMetrics(metricsFilter).counters(),
         hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(),
             "allMessages", "EOFShallNotPassFn", 6L)));
-    int successAssertions = res.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
-    res.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
+    long successAssertions = 0;
+    Iterable<MetricResult<Long>> counterResults = res.metrics().queryMetrics(
+        MetricsFilter.builder()
+            .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.SUCCESS_COUNTER))
+            .build()).counters();
+    for (MetricResult<Long> counter : counterResults) {
+      if (counter.attempted().longValue() > 0) {
+        successAssertions++;
+      }
+    }
     assertThat(
         String.format(
-            "Expected %d successful assertions, but found %d.", 1, successAssertions),
+            "Expected %d successful assertions, but found %d.", 1L, successAssertions),
             successAssertions,
-            is(1));
+            is(1L));
     // validate assertion didn't fail.
-    int failedAssertions = res.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class);
+    long failedAssertions = 0;
+    Iterable<MetricResult<Long>> failCounterResults = res.metrics().queryMetrics(
+        MetricsFilter.builder()
+            .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.FAILURE_COUNTER))
+            .build()).counters();
+    for (MetricResult<Long> counter : failCounterResults) {
+      if (counter.attempted().longValue() > 0) {
+        failedAssertions++;
+      }
+    }
     assertThat(
         String.format("Found %d failed assertions.", failedAssertions),
         failedAssertions,
-        is(0));
+        is(0L));
 
   }
 
@@ -289,8 +299,8 @@ public class ResumeFromCheckpointStreamingTest {
   /** A pass-through fn that prevents EOF event from passing. */
   private static class EOFShallNotPassFn extends DoFn<String, String> {
     final PCollectionView<List<String>> view;
-    private final Aggregator<Long, Long> aggregator =
-        createAggregator("processedMessages", Sum.ofLongs());
+    private final Counter aggregator = Metrics.counter(
+        ResumeFromCheckpointStreamingTest.class, "processedMessages");
     Counter counter =
         Metrics.counter(ResumeFromCheckpointStreamingTest.class, "allMessages");
 
@@ -305,7 +315,7 @@ public class ResumeFromCheckpointStreamingTest {
       assertThat(c.sideInput(view), containsInAnyOrder("side1", "side2"));
       counter.inc();
       if (!element.equals("EOF")) {
-        aggregator.addValue(1L);
+        aggregator.inc();
         c.output(c.element());
       }
     }
@@ -330,10 +340,8 @@ public class ResumeFromCheckpointStreamingTest {
     }
 
     private static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> {
-      private final Aggregator<Integer, Integer> success =
-          createAggregator(PAssert.SUCCESS_COUNTER, Sum.ofIntegers());
-      private final Aggregator<Integer, Integer> failure =
-          createAggregator(PAssert.FAILURE_COUNTER, Sum.ofIntegers());
+      private final Counter success = Metrics.counter(PAssert.class, PAssert.SUCCESS_COUNTER);
+      private final Counter failure = Metrics.counter(PAssert.class, PAssert.FAILURE_COUNTER);
       private final T[] expected;
 
       AssertDoFn(T[] expected) {
@@ -344,9 +352,9 @@ public class ResumeFromCheckpointStreamingTest {
       public void processElement(ProcessContext c) throws Exception {
         try {
           assertThat(c.element(), containsInAnyOrder(expected));
-          success.addValue(1);
+          success.inc();
         } catch (Throwable t) {
-          failure.addValue(1);
+          failure.inc();
           throw t;
         }
       }

http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
index 92dca53..85b8c5f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
@@ -40,9 +40,10 @@ import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
-import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
@@ -52,7 +53,6 @@ import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SimpleFunction;
-import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.WithKeys;
@@ -107,9 +107,12 @@ import org.slf4j.LoggerFactory;
 public class PAssert {
 
   private static final Logger LOG = LoggerFactory.getLogger(PAssert.class);
-
   public static final String SUCCESS_COUNTER = "PAssertSuccess";
   public static final String FAILURE_COUNTER = "PAssertFailure";
+  private static final Counter successCounter = Metrics.counter(
+      PAssert.class, PAssert.SUCCESS_COUNTER);
+  private static final Counter failureCounter = Metrics.counter(
+      PAssert.class, PAssert.FAILURE_COUNTER);
 
   private static int assertCount = 0;
 
@@ -121,6 +124,79 @@ public class PAssert {
   private PAssert() {}
 
   /**
+   * A {@link DoFn} that counts the successful {@link SuccessOrFailure} elements in the
+   * input {@link PCollection}. If a failed {@link SuccessOrFailure} is encountered, it is
+   * counted and its {@link AssertionError} is thrown immediately.
+   */
+  private static final class DefaultConcludeFn extends DoFn<SuccessOrFailure, Void> {
+
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      SuccessOrFailure e = c.element();
+      if (e.isSuccess()) {
+        PAssert.successCounter.inc();
+      } else {
+        PAssert.failureCounter.inc();
+        throw e.assertionError();
+      }
+    }
+  }
+
+  /**
+   * Default transform to check that a PAssert was successful. This transform
+   * relies on two {@link Counter} objects from the Metrics API to count the number of
+   * successful and failed asserts.
+   * Runners that do not support the Metrics API should replace this transform with
+   * their own implementation.
+   */
+  public static class DefaultConcludeTransform
+      extends PTransform<PCollection<SuccessOrFailure>, PCollection<Void>> {
+    public PCollection<Void> expand(PCollection<SuccessOrFailure> input) {
+      return input.apply(ParDo.of(new DefaultConcludeFn()));
+    }
+  }
+
+  /**
+   * Track the place where an assertion is defined.
+   * This is necessary because the stack trace of a Throwable is a transient attribute, and can't
+   * be serialized. {@link PAssertionSite} helps track the stack trace
+   * of the place where an assertion is issued.
+   */
+  public static class PAssertionSite implements Serializable {
+    private final String message;
+    private final StackTraceElement[] creationStackTrace;
+
+    static PAssertionSite capture(String message) {
+      return new PAssertionSite(message, new Throwable().getStackTrace());
+    }
+
+    PAssertionSite() {
+      this(null, new StackTraceElement[0]);
+    }
+
+    PAssertionSite(String message, StackTraceElement[] creationStackTrace) {
+      this.message = message;
+      this.creationStackTrace = creationStackTrace;
+    }
+
+    public AssertionError wrap(Throwable t) {
+      AssertionError res =
+          new AssertionError(
+              message.isEmpty() ? t.getMessage() : (message + ": " + t.getMessage()), t);
+      res.setStackTrace(creationStackTrace);
+      return res;
+    }
+
+    public AssertionError wrap(String message) {
+      String outputMessage = (this.message == null || this.message.isEmpty())
+          ? message : (this.message + ": " + message);
+      AssertionError res = new AssertionError(outputMessage);
+      res.setStackTrace(creationStackTrace);
+      return res;
+    }
+  }
+
+  /**
    * Builder interface for assertions applicable to iterables and PCollection contents.
    */
   public interface IterableAssert<T> {
@@ -400,33 +476,11 @@ public class PAssert {
 
   ////////////////////////////////////////////////////////////
 
-  private static class PAssertionSite implements Serializable {
-    private final String message;
-    private final StackTraceElement[] creationStackTrace;
-
-    static PAssertionSite capture(String message) {
-      return new PAssertionSite(message, new Throwable().getStackTrace());
-    }
-
-    PAssertionSite(String message, StackTraceElement[] creationStackTrace) {
-      this.message = message;
-      this.creationStackTrace = creationStackTrace;
-    }
-
-    public AssertionError wrap(Throwable t) {
-      AssertionError res =
-          new AssertionError(
-              message.isEmpty() ? t.getMessage() : (message + ": " + t.getMessage()), t);
-      res.setStackTrace(creationStackTrace);
-      return res;
-    }
-  }
-
   /**
    * An {@link IterableAssert} about the contents of a {@link PCollection}. This does not require
    * the runner to support side inputs.
    */
-  private static class PCollectionContentsAssert<T> implements IterableAssert<T> {
+  protected static class PCollectionContentsAssert<T> implements IterableAssert<T> {
     private final PCollection<T> actual;
     private final AssertionWindows rewindowingStrategy;
     private final SimpleFunction<Iterable<ValueInSingleWindow<T>>, Iterable<T>> paneExtractor;
@@ -560,7 +614,8 @@ public class PAssert {
       return this;
     }
 
-    private static class MatcherCheckerFn<T> implements SerializableFunction<T, Void> {
+    /** Check that the passed-in matchers match the existing data. */
+    protected static class MatcherCheckerFn<T> implements SerializableFunction<T, Void> {
       private SerializableMatcher<T> matcher;
 
       public MatcherCheckerFn(SerializableMatcher<T> matcher) {
@@ -690,7 +745,8 @@ public class PAssert {
         SerializableFunction<Iterable<T>, Void> checkerFn) {
       actual.apply(
           "PAssert$" + (assertCount++),
-          new GroupThenAssertForSingleton<>(checkerFn, rewindowingStrategy, paneExtractor, site));
+          new GroupThenAssertForSingleton<>(
+              checkerFn, rewindowingStrategy, paneExtractor, site));
       return this;
     }
 
@@ -1033,7 +1089,8 @@ public class PAssert {
           .apply("GroupGlobally", new GroupGlobally<T>(rewindowingStrategy))
           .apply("GetPane", MapElements.via(paneExtractor))
           .setCoder(IterableCoder.of(input.getCoder()))
-          .apply("RunChecks", ParDo.of(new GroupedValuesCheckerDoFn<>(checkerFn, site)));
+          .apply("RunChecks", ParDo.of(new GroupedValuesCheckerDoFn<>(checkerFn, site)))
+          .apply("VerifyAssertions", new DefaultConcludeTransform());
 
       return PDone.in(input.getPipeline());
     }
@@ -1069,7 +1126,8 @@ public class PAssert {
           .apply("GroupGlobally", new GroupGlobally<Iterable<T>>(rewindowingStrategy))
           .apply("GetPane", MapElements.via(paneExtractor))
           .setCoder(IterableCoder.of(input.getCoder()))
-          .apply("RunChecks", ParDo.of(new SingletonCheckerDoFn<>(checkerFn, site)));
+          .apply("RunChecks", ParDo.of(new SingletonCheckerDoFn<>(checkerFn, site)))
+          .apply("VerifyAssertions", new DefaultConcludeTransform());
 
       return PDone.in(input.getPipeline());
     }
@@ -1112,8 +1170,8 @@ public class PAssert {
           .apply("WindowToken", windowToken)
           .apply(
               "RunChecks",
-              ParDo.of(new SideInputCheckerDoFn<>(checkerFn, actual, site)).withSideInputs(actual));
-
+              ParDo.of(new SideInputCheckerDoFn<>(checkerFn, actual, site)).withSideInputs(actual))
+          .apply("VerifyAssertions", new DefaultConcludeTransform());
       return PDone.in(input.getPipeline());
     }
   }
@@ -1125,12 +1183,8 @@ public class PAssert {
    * <p>The input is ignored, but is {@link Integer} to be usable on runners that do not support
    * null values.
    */
-  private static class SideInputCheckerDoFn<ActualT> extends DoFn<Integer, Void> {
+  private static class SideInputCheckerDoFn<ActualT> extends DoFn<Integer, SuccessOrFailure> {
     private final SerializableFunction<ActualT, Void> checkerFn;
-    private final Aggregator<Integer, Integer> success =
-        createAggregator(SUCCESS_COUNTER, Sum.ofIntegers());
-    private final Aggregator<Integer, Integer> failure =
-        createAggregator(FAILURE_COUNTER, Sum.ofIntegers());
     private final PCollectionView<ActualT> actual;
     private final PAssertionSite site;
 
@@ -1146,7 +1200,7 @@ public class PAssert {
     @ProcessElement
     public void processElement(ProcessContext c) {
       ActualT actualContents = c.sideInput(actual);
-      doChecks(site, actualContents, checkerFn, success, failure);
+      c.output(doChecks(site, actualContents, checkerFn));
     }
   }
 
@@ -1157,12 +1211,8 @@ public class PAssert {
    *
    * <p>The singleton property is presumed, not enforced.
    */
-  private static class GroupedValuesCheckerDoFn<ActualT> extends DoFn<ActualT, Void> {
+  private static class GroupedValuesCheckerDoFn<ActualT> extends DoFn<ActualT, SuccessOrFailure> {
     private final SerializableFunction<ActualT, Void> checkerFn;
-    private final Aggregator<Integer, Integer> success =
-        createAggregator(SUCCESS_COUNTER, Sum.ofIntegers());
-    private final Aggregator<Integer, Integer> failure =
-        createAggregator(FAILURE_COUNTER, Sum.ofIntegers());
     private final PAssertionSite site;
 
     private GroupedValuesCheckerDoFn(
@@ -1173,7 +1223,11 @@ public class PAssert {
 
     @ProcessElement
     public void processElement(ProcessContext c) {
-      doChecks(site, c.element(), checkerFn, success, failure);
+      try {
+        c.output(doChecks(site, c.element(), checkerFn));
+      } catch (Throwable t) {
+        throw t;
+      }
     }
   }
 
@@ -1185,12 +1239,9 @@ public class PAssert {
    * <p>The singleton property of the input {@link PCollection} is presumed, not enforced. However,
    * each input element must be a singleton iterable, or this will fail.
    */
-  private static class SingletonCheckerDoFn<ActualT> extends DoFn<Iterable<ActualT>, Void> {
+  private static class SingletonCheckerDoFn<ActualT>
+      extends DoFn<Iterable<ActualT>, SuccessOrFailure> {
     private final SerializableFunction<ActualT, Void> checkerFn;
-    private final Aggregator<Integer, Integer> success =
-        createAggregator(SUCCESS_COUNTER, Sum.ofIntegers());
-    private final Aggregator<Integer, Integer> failure =
-        createAggregator(FAILURE_COUNTER, Sum.ofIntegers());
     private final PAssertionSite site;
 
     private SingletonCheckerDoFn(
@@ -1202,22 +1253,21 @@ public class PAssert {
     @ProcessElement
     public void processElement(ProcessContext c) {
       ActualT actualContents = Iterables.getOnlyElement(c.element());
-      doChecks(site, actualContents, checkerFn, success, failure);
+      c.output(doChecks(site, actualContents, checkerFn));
     }
   }
 
-  private static <ActualT> void doChecks(
+  protected static <ActualT> SuccessOrFailure doChecks(
       PAssertionSite site,
       ActualT actualContents,
-      SerializableFunction<ActualT, Void> checkerFn,
-      Aggregator<Integer, Integer> successAggregator,
-      Aggregator<Integer, Integer> failureAggregator) {
+      SerializableFunction<ActualT, Void> checkerFn) {
+    SuccessOrFailure result = SuccessOrFailure.success();
     try {
       checkerFn.apply(actualContents);
-      successAggregator.addValue(1);
     } catch (Throwable t) {
-      failureAggregator.addValue(1);
-      throw site.wrap(t);
+      result = SuccessOrFailure.failure(site, t.getMessage());
+    } finally {
+      return result;
     }
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java
new file mode 100644
index 0000000..04e3c35
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+import com.google.common.base.MoreObjects;
+import java.io.Serializable;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.DefaultCoder;
+import org.apache.beam.sdk.coders.SerializableCoder;
+
+/**
+ * The result of a single {@link PAssert} check, passed to a conclude transform that acts on it.
+ */
+@DefaultCoder(SerializableCoder.class)
+public final class SuccessOrFailure implements Serializable {
+  // TODO(BEAM-1898): Add a SerializableThrowable instead of relying on PAssertionSite.
+
+  private final boolean isSuccess;
+  @Nullable
+  private final PAssert.PAssertionSite site;
+  @Nullable
+  private final String message;
+
+  private SuccessOrFailure() {
+    this(true, null, null);
+  }
+
+  private SuccessOrFailure(
+      boolean isSuccess,
+      @Nullable PAssert.PAssertionSite site,
+      @Nullable String message) {
+    this.isSuccess = isSuccess;
+    this.site = site;
+    this.message = message;
+  }
+
+  public boolean isSuccess() {
+    return isSuccess;
+  }
+
+  @Nullable
+  public AssertionError assertionError() {
+    return  site == null ? null : site.wrap(message);
+  }
+
+  public static SuccessOrFailure success() {
+    return new SuccessOrFailure(true, null, null);
+  }
+
+  public static SuccessOrFailure failure(@Nullable PAssert.PAssertionSite site,
+      @Nullable String message) {
+    return new SuccessOrFailure(false, site, message);
+  }
+
+  public static SuccessOrFailure failure(@Nullable PAssert.PAssertionSite site) {
+    return new SuccessOrFailure(false, site, null);
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("isSuccess", isSuccess())
+        .addValue(message)
+        .omitNullValues()
+        .toString();
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
index 9d580e4..2ef892c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.testing;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -36,8 +37,10 @@ import java.util.regex.Pattern;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.AtomicCoder;
 import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.io.GenerateSequence;
+import org.apache.beam.sdk.testing.PAssert.PCollectionContentsAssert.MatcherCheckerFn;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Sum;
@@ -46,6 +49,7 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
@@ -117,6 +121,44 @@ public class PAssertTest implements Serializable {
     }
   }
 
+  @Test
+  public void testFailureEncodedDecoded() throws IOException {
+    AssertionError error = null;
+    try {
+      assertEquals(0, 1);
+    } catch (AssertionError e) {
+      error = e;
+    }
+    SuccessOrFailure failure = SuccessOrFailure.failure(
+        new PAssert.PAssertionSite(error.getMessage(), error.getStackTrace()));
+    SerializableCoder<SuccessOrFailure> coder = SerializableCoder.of(SuccessOrFailure.class);
+
+    byte[] encoded = CoderUtils.encodeToByteArray(coder, failure);
+    SuccessOrFailure res = CoderUtils.decodeFromByteArray(coder, encoded);
+
+    // Compare string forms, because Throwable does not override equals().
+    assertEquals("Encode-decode failed SuccessOrFailure",
+        failure.assertionError().toString(), res.assertionError().toString());
+    String resultStacktrace = Throwables.getStackTraceAsString(res.assertionError());
+    String failureStacktrace = Throwables.getStackTraceAsString(failure.assertionError());
+    assertThat(resultStacktrace, is(failureStacktrace));
+  }
+
+  @Test
+  public void testSuccessEncodedDecoded() throws IOException {
+    SuccessOrFailure success = SuccessOrFailure.success();
+    SerializableCoder<SuccessOrFailure> coder = SerializableCoder.of(SuccessOrFailure.class);
+
+    byte[] encoded = CoderUtils.encodeToByteArray(coder, success);
+    SuccessOrFailure res = CoderUtils.decodeFromByteArray(coder, encoded);
+
+    assertEquals("Encode-decode successful SuccessOrFailure",
+        success.isSuccess(), res.isSuccess());
+    assertEquals("Encode-decode successful SuccessOrFailure",
+        success.assertionError(),
+        res.assertionError());
+  }
+
   /**
    * A {@link PAssert} about the contents of a {@link PCollection}
    * must not require the contents of the {@link PCollection} to be
@@ -452,6 +494,19 @@ public class PAssertTest implements Serializable {
   }
 
   @Test
+  public void testAssertionSiteIsCaptured() {
+    // 10 does not match contains(11), so doChecks should return a failure.
+    SuccessOrFailure res = PAssert.doChecks(
+        PAssert.PAssertionSite.capture("Captured assertion message."),
+        new Integer(10),
+        new MatcherCheckerFn(SerializableMatchers.contains(new Integer(11))));
+
+    String stacktrace = Throwables.getStackTraceAsString(res.assertionError());
+    assertEquals(res.isSuccess(), false);
+    assertThat(stacktrace, containsString("PAssertionSite.capture"));
+  }
+
+  @Test
   @Category(ValidatesRunner.class)
   public void testAssertionSiteIsCapturedWithMessage() throws Exception {
     PCollection<Long> vals = pipeline.apply(GenerateSequence.from(0).to(5));


[2/2] beam git commit: Closes #2417

Posted by bc...@apache.org.
Closes #2417


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/7a2fe68f
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/7a2fe68f
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/7a2fe68f

Branch: refs/heads/master
Commit: 7a2fe68fd37bf60b357d9894ba0c365892154bcb
Parents: 37e5321 e8f0922
Author: bchambers <bc...@google.com>
Authored: Mon Apr 24 11:17:08 2017 -0700
Committer: bchambers <bc...@google.com>
Committed: Mon Apr 24 11:56:59 2017 -0700

----------------------------------------------------------------------
 .../beam/runners/spark/TestSparkRunner.java     |  35 ++--
 .../ResumeFromCheckpointStreamingTest.java      |  72 ++++----
 .../org/apache/beam/sdk/testing/PAssert.java    | 164 ++++++++++++-------
 .../beam/sdk/testing/SuccessOrFailure.java      |  82 ++++++++++
 .../apache/beam/sdk/testing/PAssertTest.java    |  55 +++++++
 5 files changed, 307 insertions(+), 101 deletions(-)
----------------------------------------------------------------------