You are viewing a plain-text version of this content. (The canonical link to it was not preserved in this copy.)
Posted to commits@beam.apache.org by GitBox <gi...@apache.org> on 2019/01/15 06:26:41 UTC
[beam] Diff for: [GitHub] kennknowles merged pull request #7506: [BEAM-6430] Fix EXCEPT
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java
index 581fb08e83f8..4827e2d579f8 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java
@@ -25,6 +25,7 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterators;
/** Collections of {@code PTransform} and {@code DoFn} used to perform Set operations. */
public abstract class BeamSetOperatorsTransforms {
@@ -89,6 +90,8 @@ public void processElement(ProcessContext ctx) {
}
break;
case MINUS:
+ // Say for Row R, there are m instances on left and n instances on right,
+ // EXCEPT ALL outputs MAX(m - n, 0) instances of R.
if (leftRows.iterator().hasNext() && !rightRows.iterator().hasNext()) {
Iterator<Row> iter = leftRows.iterator();
if (all) {
@@ -100,6 +103,21 @@ public void processElement(ProcessContext ctx) {
// only output one
ctx.output(iter.next());
}
+ } else if (leftRows.iterator().hasNext() && rightRows.iterator().hasNext()) {
+ int leftCount = Iterators.size(leftRows.iterator());
+ int rightCount = Iterators.size(rightRows.iterator());
+
+ int outputCount = leftCount - rightCount;
+ if (outputCount > 0) {
+ if (all) {
+ while (outputCount > 0) {
+ outputCount--;
+ ctx.output(ctx.element().getKey());
+ }
+ } else {
+ ctx.output(ctx.element().getKey());
+ }
+ }
}
}
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java
index 322f8fb88712..5ac5bce878f0 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java
@@ -93,7 +93,7 @@ public void testExcept() throws Exception {
Schema.FieldType.INT64, "order_id",
Schema.FieldType.INT32, "site_id",
Schema.FieldType.DECIMAL, "price")
- .addRows(4L, 4, new BigDecimal(4.0))
+ .addRows(1L, 1, new BigDecimal(1.0), 4L, 4, new BigDecimal(4.0))
.getRows());
pipeline.run();
@@ -110,7 +110,7 @@ public void testExceptAll() throws Exception {
+ "FROM ORDER_DETAILS2 ";
PCollection<Row> rows = compilePipeline(sql, pipeline);
- PAssert.that(rows).satisfies(new CheckSize(2));
+ PAssert.that(rows).satisfies(new CheckSize(3));
PAssert.that(rows)
.containsInAnyOrder(
@@ -118,7 +118,16 @@ public void testExceptAll() throws Exception {
Schema.FieldType.INT64, "order_id",
Schema.FieldType.INT32, "site_id",
Schema.FieldType.DECIMAL, "price")
- .addRows(4L, 4, new BigDecimal(4.0), 4L, 4, new BigDecimal(4.0))
+ .addRows(
+ 1L,
+ 1,
+ new BigDecimal(1.0),
+ 4L,
+ 4,
+ new BigDecimal(4.0),
+ 4L,
+ 4,
+ new BigDecimal(4.0))
.getRows());
pipeline.run();
With regards,
Apache Git Services