Posted to commits@beam.apache.org by re...@apache.org on 2019/03/01 15:05:57 UTC
[beam] branch master updated: Merge pull request #7353: [BEAM-4461]
Support inner and outer style joins in CoGroup.
This is an automated email from the ASF dual-hosted git repository.
reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c41b3c0 Merge pull request #7353: [BEAM-4461] Support inner and outer style joins in CoGroup.
c41b3c0 is described below
commit c41b3c082d924059c393345f4b1e740804d2b877
Author: reuvenlax <re...@google.com>
AuthorDate: Fri Mar 1 07:05:46 2019 -0800
Merge pull request #7353: [BEAM-4461] Support inner and outer style joins in CoGroup.
---
.../beam/sdk/schemas/transforms/CoGroup.java | 612 ++++++++++++++++-----
.../org/apache/beam/sdk/transforms/Create.java | 13 +
.../beam/sdk/transforms/join/CoGbkResult.java | 17 +
.../sdk/transforms/join/KeyedPCollectionTuple.java | 20 +
.../apache/beam/sdk/values/PCollectionTuple.java | 94 ++++
.../main/java/org/apache/beam/sdk/values/Row.java | 4 +-
.../beam/sdk/schemas/transforms/CoGroupTest.java | 352 +++++++++++-
7 files changed, 932 insertions(+), 180 deletions(-)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
index ebcf418..3822821 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
@@ -17,11 +17,14 @@
*/
package org.apache.beam.sdk.schemas.transforms;
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
import java.util.Collections;
-import java.util.Comparator;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.TreeMap;
+import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.KvCoder;
@@ -33,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
@@ -57,14 +61,9 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
* <p>For example, the following demonstrates joining three PCollections on the "user" and "country"
* fields:
*
- * <pre>{@code TupleTag<Input1Type> input1Tag = new TupleTag<>("input1");
- * TupleTag<Input2Type> input2Tag = new TupleTag<>("input2");
- * TupleTag<Input3Type> input3Tag = new TupleTag<>("input3");
- * PCollection<KV<Row, Row>> joined = PCollectionTuple
- * .of(input1Tag, input1)
- * .and(input2Tag, input2)
- * .and(input3Tag, input3)
- * .apply(CoGroup.byFieldNames("user", "country"));
+ * <pre>{@code PCollection<KV<Row, Row>> joined =
+ * PCollectionTuple.of("input1", input1, "input2", input2, "input3", input3)
+ * .apply(CoGroup.join(By.fieldNames("user", "country")));
* }</pre>
*
* <p>In the above case, the key schema will contain the two string fields "user" and "country"; in
@@ -107,156 +106,251 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
* those fields match. In this case, fields must be specified for every input PCollection. For
* example:
*
- * <pre>{@code PCollection<KV<Row, Row>> joined = PCollectionTuple
- * .of(input1Tag, input1)
- * .and(input2Tag, input2)
+ * <pre>{@code PCollection<KV<Row, Row>> joined
+ * = PCollectionTuple.of("input1Tag", input1, "input2Tag", input2)
* .apply(CoGroup
- * .byFieldNames(input1Tag, "referringUser"))
- * .byFieldNames(input2Tag, "user"));
+ * .join("input1Tag", By.fieldNames("referringUser"))
+ * .join("input2Tag", By.fieldNames("user")));
* }</pre>
+ *
+ * <p>Traditional (SQL) joins are cross-product joins. All rows that match the join condition are
+ * combined into individual rows and returned; in fact any SQL inner join is a subset of the
+ * cross-product of two tables. This transform also supports the same functionality using the {@link
+ * Inner#crossProductJoin()} method.
+ *
+ * <p>For example, consider the SQL join: SELECT * FROM input1 INNER JOIN input2 ON input1.user =
+ * input2.user
+ *
+ * <p>You could express this with:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ * .apply(CoGroup.join(By.fieldNames("user")).crossProductJoin());
+ * }</pre>
+ *
+ * <p>The schema of the output PCollection contains a nested message for each of input1 and input2.
+ * Like above, you could use the {@link Convert} transform to convert it to the following POJO:
+ *
+ * <pre>{@code
+ * {@literal @}DefaultSchema(JavaFieldSchema.class)
+ * public class JoinedValue {
+ * public Input1Type input1;
+ * public Input2Type input2;
+ * }
+ * }</pre>
+ *
+ * <p>The {@link Unnest} transform can then be used to flatten all the subfields into one single
+ * top-level row containing all the fields in both Input1 and Input2; this will often be combined
+ * with a {@link Select} transform to select out the fields of interest, as the key fields will be
+ * identical between input1 and input2.
+ *
+ * <p>This transform also supports outer-join semantics. By default, all input PCollections must
+ * participate fully in the join, providing inner-join semantics. This means that the join will only
+ * produce values for "Bob" if all inputs have values for "Bob;" if even a single input does not
+ * have a value for "Bob," an inner-join will produce no value. However, if you mark that input as
+ * having outer-join participation then the join will contain values for "Bob," as long as at least
+ * one input has a "Bob" value; null values will be added for inputs that have no "Bob" values. To
+ * continue the SQL example:
+ *
+ * <p>SELECT * FROM input1 LEFT OUTER JOIN input2 ON input1.user = input2.user
+ *
+ * <p>Is equivalent to:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ * .apply(CoGroup.join("input1", By.fieldNames("user").withOuterJoinParticipation())
+ * .join("input2", By.fieldNames("user"))
+ * .crossProductJoin());
+ * }</pre>
+ *
+ * <p>SELECT * FROM input1 RIGHT OUTER JOIN input2 ON input1.user = input2.user
+ *
+ * <p>Is equivalent to:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ * .apply(CoGroup.join("input1", By.fieldNames("user"))
+ * .join("input2", By.fieldNames("user").withOuterJoinParticipation())
+ * .crossProductJoin());
+ * }</pre>
+ *
+ * <p>and SELECT * FROM input1 FULL OUTER JOIN input2 ON input1.user = input2.user
+ *
+ * <p>Is equivalent to:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ * .apply(CoGroup.join("input1", By.fieldNames("user").withOuterJoinParticipation())
+ * .join("input2", By.fieldNames("user").withOuterJoinParticipation())
+ * .crossProductJoin());
+ * }</pre>
+ *
+ * <p>While the above examples use two inputs to mimic SQL's left and right join semantics, the
+ * {@link CoGroup} transform supports any number of inputs, and outer-join participation can be
+ * specified on any subset of them.
+ *
+ * <p>Do note that cross-product joins, while simpler and easier to program, can cause
*/
public class CoGroup {
- /**
- * Join by the following field names.
- *
- * <p>The same field names are used in all input PCollections.
- */
- public static Inner byFieldNames(String... fieldNames) {
- return byFieldAccessDescriptor(FieldAccessDescriptor.withFieldNames(fieldNames));
- }
+ private static final List NULL_LIST;
- /**
- * Join by the following field ids.
- *
- * <p>The same field ids are used in all input PCollections.
- */
- public static Inner byFieldIds(Integer... fieldIds) {
- return byFieldAccessDescriptor(FieldAccessDescriptor.withFieldIds(fieldIds));
+ static {
+ NULL_LIST = Lists.newArrayList();
+ NULL_LIST.add(null);
}
/**
- * Join by the following {@link FieldAccessDescriptor}.
- *
- * <p>The same access descriptor is used in all input PCollections.
+ * Defines the set of fields to extract for the join key, as well as other per-input join options.
*/
- public static Inner byFieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) {
- return new Inner(fieldAccessDescriptor);
- }
+ @AutoValue
+ public abstract static class By implements Serializable {
+ abstract FieldAccessDescriptor getFieldAccessDescriptor();
- /**
- * Select the following field names for the specified PCollection.
- *
- * <p>Each PCollection in the input must have fields specified for the join key.
- */
- public static Inner byFieldNames(TupleTag<?> tag, String... fieldNames) {
- return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldNames(fieldNames));
- }
+ abstract boolean getOuterJoinParticipation();
- /**
- * Select the following field ids for the specified PCollection.
- *
- * <p>Each PCollection in the input must have fields specified for the join key.
- */
- public static Inner byFieldIds(TupleTag<?> tag, Integer... fieldIds) {
- return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldIds(fieldIds));
- }
+ abstract Builder toBuilder();
- /**
- * Select the following fields for the specified PCollection using {@link FieldAccessDescriptor}.
- *
- * <p>Each PCollection in the input must have fields specified for the join key.
- */
- public static Inner byFieldAccessDescriptor(
- TupleTag<?> tag, FieldAccessDescriptor fieldAccessDescriptor) {
- return new Inner().byFieldAccessDescriptor(tag, fieldAccessDescriptor);
- }
+ @AutoValue.Builder
+ abstract static class Builder {
+ abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor);
- /** The implementing PTransform. */
- public static class Inner extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> {
- @Nullable private final FieldAccessDescriptor allInputsFieldAccessDescriptor;
- private final Map<TupleTag<?>, FieldAccessDescriptor> fieldAccessDescriptorMap;
+ abstract Builder setOuterJoinParticipation(boolean outerJoinParticipation);
- private Inner() {
- this(Collections.emptyMap());
+ abstract By build();
}
- private Inner(Map<TupleTag<?>, FieldAccessDescriptor> fieldAccessDescriptorMap) {
- this.allInputsFieldAccessDescriptor = null;
- this.fieldAccessDescriptorMap = fieldAccessDescriptorMap;
+ /** Join by the following field names. */
+ public static By fieldNames(String... fieldNames) {
+ return fieldAccessDescriptor(FieldAccessDescriptor.withFieldNames(fieldNames));
}
- private Inner(FieldAccessDescriptor allInputsFieldAccessDescriptor) {
- this.allInputsFieldAccessDescriptor = allInputsFieldAccessDescriptor;
- this.fieldAccessDescriptorMap = Collections.emptyMap();
+ /** Join by the following field ids. */
+ public static By fieldIds(Integer... fieldIds) {
+ return fieldAccessDescriptor(FieldAccessDescriptor.withFieldIds(fieldIds));
}
- /**
- * Join by the following field names.
- *
- * <p>The same field names are used in all input PCollections.
- */
- public Inner byFieldNames(TupleTag<?> tag, String... fieldNames) {
- return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldNames(fieldNames));
+ /** Join by the following field access descriptor. */
+ public static By fieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) {
+ return new AutoValue_CoGroup_By.Builder()
+ .setFieldAccessDescriptor(fieldAccessDescriptor)
+ .setOuterJoinParticipation(false)
+ .build();
}
/**
- * Select the following field ids for the specified PCollection.
+ * Means that this field will participate in a join even when not present, similar to SQL
+ * outer-join semantics. Missing entries will be replaced by nulls.
*
- * <p>Each PCollection in the input must have fields specified for the join key.
+ * <p>This only affects the results of expandCrossProduct.
*/
- public Inner byFieldIds(TupleTag<?> tag, Integer... fieldIds) {
- return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldIds(fieldIds));
+ public By withOuterJoinParticipation() {
+ return toBuilder().setOuterJoinParticipation(true).build();
}
+ }
- /**
- * Select the following fields for the specified PCollection using {@link
- * FieldAccessDescriptor}.
- *
- * <p>Each PCollection in the input must have fields specified for the join key.
- */
- public Inner byFieldAccessDescriptor(
- TupleTag<?> tag, FieldAccessDescriptor fieldAccessDescriptor) {
- if (allInputsFieldAccessDescriptor != null) {
- throw new IllegalStateException("Cannot set both a global and per-tag fields.");
- }
- return new Inner(
- new ImmutableMap.Builder<TupleTag<?>, FieldAccessDescriptor>()
- .putAll(fieldAccessDescriptorMap)
- .put(tag, fieldAccessDescriptor)
- .build());
+ private static class JoinArguments implements Serializable {
+ @Nullable private final By allInputsJoinArgs;
+ private final Map<String, By> joinArgsMap;
+
+ JoinArguments(@Nullable By allInputsJoinArgs) {
+ this.allInputsJoinArgs = allInputsJoinArgs;
+ this.joinArgsMap = Collections.emptyMap();
+ }
+
+ JoinArguments(Map<String, By> joinArgsMap) {
+ this.allInputsJoinArgs = null;
+ this.joinArgsMap = joinArgsMap;
+ }
+
+ JoinArguments with(String tag, By clause) {
+ return new JoinArguments(
+ new ImmutableMap.Builder<String, By>().putAll(joinArgsMap).put(tag, clause).build());
}
@Nullable
- private FieldAccessDescriptor getFieldAccessDescriptor(TupleTag<?> tag) {
- return (allInputsFieldAccessDescriptor != null)
- ? allInputsFieldAccessDescriptor
- : fieldAccessDescriptorMap.get(tag);
+ private FieldAccessDescriptor getFieldAccessDescriptor(String tag) {
+ return (allInputsJoinArgs != null)
+ ? allInputsJoinArgs.getFieldAccessDescriptor()
+ : joinArgsMap.get(tag).getFieldAccessDescriptor();
}
- @Override
- public PCollection<KV<Row, Row>> expand(PCollectionTuple input) {
+ private boolean getOuterJoinParticipation(String tag) {
+ return (allInputsJoinArgs != null)
+ ? allInputsJoinArgs.getOuterJoinParticipation()
+ : joinArgsMap.get(tag).getOuterJoinParticipation();
+ }
+ }
+
+ /**
+ * Join all input PCollections using the same args.
+ *
+ * <p>The same fields and other options are used in all input PCollections.
+ */
+ public static Inner join(By clause) {
+ return new Inner(new JoinArguments(clause));
+ }
+
+ /**
+ * Specify the following join arguments (including fields to join by) for the specified
+ * PCollection.
+ *
+ * <p>Each PCollection in the input must have args specified for the join key.
+ */
+ public static Inner join(String tag, By clause) {
+ return new Inner(new JoinArguments(ImmutableMap.of(tag, clause)));
+ }
+
+ // Contains summary information needed for implementing the join.
+ private static class JoinInformation {
+ private final KeyedPCollectionTuple<Row> keyedPCollectionTuple;
+ private final Schema keySchema;
+ private final Map<String, Schema> componentSchemas;
+ // Maps from index in sortedTags to the toRow function.
+ private final Map<Integer, SerializableFunction<Object, Row>> toRows;
+ private final List<String> sortedTags;
+ private final Map<Integer, String> tagToKeyedTag;
+
+ private JoinInformation(
+ KeyedPCollectionTuple<Row> keyedPCollectionTuple,
+ Schema keySchema,
+ Map<String, Schema> componentSchemas,
+ Map<Integer, SerializableFunction<Object, Row>> toRows,
+ List<String> sortedTags,
+ Map<Integer, String> tagToKeyedTag) {
+ this.keyedPCollectionTuple = keyedPCollectionTuple;
+ this.keySchema = keySchema;
+ this.componentSchemas = componentSchemas;
+ this.toRows = toRows;
+ this.sortedTags = sortedTags;
+ this.tagToKeyedTag = tagToKeyedTag;
+ }
+
+ private static JoinInformation from(
+ PCollectionTuple input, Function<String, FieldAccessDescriptor> getFieldAccessDescriptor) {
KeyedPCollectionTuple<Row> keyedPCollectionTuple =
KeyedPCollectionTuple.empty(input.getPipeline());
- List<TupleTag<Row>> sortedTags =
+
+ List<String> sortedTags =
input.getAll().keySet().stream()
- .sorted(Comparator.comparing(TupleTag::getId))
- .map(t -> new TupleTag<Row>(t.getId() + "_ROW"))
+ .map(TupleTag::getId)
+ .sorted()
.collect(Collectors.toList());
// Keep this in a TreeMap so that it's sorted. This way we get a deterministic output
// schema.
TreeMap<String, Schema> componentSchemas = Maps.newTreeMap();
- Map<String, SerializableFunction<Object, Row>> toRows = Maps.newHashMap();
+ Map<Integer, SerializableFunction<Object, Row>> toRows = Maps.newHashMap();
+ Map<Integer, String> tagToKeyedTag = Maps.newHashMap();
Schema keySchema = null;
for (Map.Entry<TupleTag<?>, PCollection<?>> entry : input.getAll().entrySet()) {
- TupleTag<?> tag = entry.getKey();
+ String tag = entry.getKey().getId();
+ int tagIndex = sortedTags.indexOf(tag);
PCollection<?> pc = entry.getValue();
Schema schema = pc.getSchema();
- componentSchemas.put(tag.getId(), schema);
- TupleTag<Row> rowTag = new TupleTag<>(tag.getId() + "_ROW");
- toRows.put(rowTag.getId(), (SerializableFunction<Object, Row>) pc.getToRowFunction());
- FieldAccessDescriptor fieldAccessDescriptor = getFieldAccessDescriptor(tag);
+ componentSchemas.put(tag, schema);
+ toRows.put(tagIndex, (SerializableFunction<Object, Row>) pc.getToRowFunction());
+ FieldAccessDescriptor fieldAccessDescriptor = getFieldAccessDescriptor.apply(tag);
if (fieldAccessDescriptor == null) {
throw new IllegalStateException("No fields were set for input " + tag);
}
@@ -275,51 +369,150 @@ public class CoGroup {
}
}
+ // Create a new tag for the output.
+ TupleTag randomTag = new TupleTag<>();
+ String keyedTag = tag + "_" + randomTag;
+ tagToKeyedTag.put(tagIndex, keyedTag);
PCollection<KV<Row, Row>> keyedPCollection =
- extractKey(pc, schema, keySchema, resolved, tag.getId());
- keyedPCollectionTuple = keyedPCollectionTuple.and(rowTag, keyedPCollection);
+ extractKey(pc, schema, keySchema, resolved, tag);
+ keyedPCollectionTuple = keyedPCollectionTuple.and(keyedTag, keyedPCollection);
}
+ return new JoinInformation(
+ keyedPCollectionTuple, keySchema, componentSchemas, toRows, sortedTags, tagToKeyedTag);
+ }
+ private static <T> PCollection<KV<Row, Row>> extractKey(
+ PCollection<T> pCollection,
+ Schema schema,
+ Schema keySchema,
+ FieldAccessDescriptor keyFields,
+ String tag) {
+ return pCollection
+ .apply(
+ "extractKey" + tag,
+ ParDo.of(
+ new DoFn<T, KV<Row, Row>>() {
+ @ProcessElement
+ public void process(@Element Row row, OutputReceiver<KV<Row, Row>> o) {
+ o.output(KV.of(Select.selectRow(row, keyFields, schema, keySchema), row));
+ }
+ }))
+ .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)));
+ }
+ }
+
+ static void verify(PCollectionTuple input, JoinArguments joinArgs) {
+ if (joinArgs.allInputsJoinArgs == null) {
+ // If explicit join tags were specified, then they must match the input tuple.
+ Set<String> inputTags =
+ input.getAll().keySet().stream().map(TupleTag::getId).collect(Collectors.toSet());
+ Set<String> joinTags = joinArgs.joinArgsMap.keySet();
+ if (!inputTags.equals(joinTags)) {
+ throw new IllegalArgumentException(
+ "The input PCollectionTuple has tags: "
+ + inputTags
+ + " and the join was specified for tags "
+ + joinTags
+ + ". These do not match.");
+ }
+ }
+ }
+
+ /** The implementing PTransform. */
+ public static class Inner extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> {
+ private final JoinArguments joinArgs;
+
+ private Inner() {
+ this(new JoinArguments(Collections.emptyMap()));
+ }
+
+ private Inner(JoinArguments joinArgs) {
+ this.joinArgs = joinArgs;
+ }
+
+ /**
+ * Select the following fields for the specified PCollection with the specified join args.
+ *
+ * <p>Each PCollection in the input must have fields specified for the join key.
+ */
+ public Inner join(String tag, By clause) {
+ if (joinArgs.allInputsJoinArgs != null) {
+ throw new IllegalStateException("Cannot set both a global and per-tag fields.");
+ }
+ return new Inner(joinArgs.with(tag, clause));
+ }
+
+ /** Expand the join into individual rows, similar to SQL joins. */
+ public ExpandCrossProduct crossProductJoin() {
+ return new ExpandCrossProduct(joinArgs);
+ }
+
+ private Schema getOutputSchema(JoinInformation joinInformation) {
// Construct the output schema. It contains one field for each input PCollection, of type
// ARRAY[ROW].
Schema.Builder joinedSchemaBuilder = Schema.builder();
- for (Map.Entry<String, Schema> entry : componentSchemas.entrySet()) {
+ for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
joinedSchemaBuilder.addArrayField(entry.getKey(), FieldType.row(entry.getValue()));
}
- Schema joinedSchema = joinedSchemaBuilder.build();
+ return joinedSchemaBuilder.build();
+ }
+
+ @Override
+ public PCollection<KV<Row, Row>> expand(PCollectionTuple input) {
+ verify(input, joinArgs);
- return keyedPCollectionTuple
+ JoinInformation joinInformation =
+ JoinInformation.from(input, joinArgs::getFieldAccessDescriptor);
+
+ Schema joinedSchema = getOutputSchema(joinInformation);
+
+ return joinInformation
+ .keyedPCollectionTuple
.apply("CoGroupByKey", CoGroupByKey.create())
- .apply("ConvertToRow", ParDo.of(new ConvertToRow(sortedTags, toRows, joinedSchema)))
- .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(joinedSchema)));
+ .apply(
+ "ConvertToRow",
+ ParDo.of(
+ new ConvertToRow(
+ joinInformation.sortedTags,
+ joinInformation.toRows,
+ joinedSchema,
+ joinInformation.tagToKeyedTag)))
+ .setCoder(
+ KvCoder.of(SchemaCoder.of(joinInformation.keySchema), SchemaCoder.of(joinedSchema)));
}
+ // Used by the unexpanded join to create the output rows.
private static class ConvertToRow extends DoFn<KV<Row, CoGbkResult>, KV<Row, Row>> {
- List<TupleTag<Row>> sortedTags;
- Map<String, SerializableFunction<Object, Row>> toRows = Maps.newHashMap();
- Schema joinedSchema;
-
- public ConvertToRow(
- List<TupleTag<Row>> sortedTags,
- Map<String, SerializableFunction<Object, Row>> toRows,
- Schema joinedSchema) {
+ private final List<String> sortedTags;
+ private final Map<Integer, SerializableFunction<Object, Row>> toRows;
+ private final Map<Integer, String> tagToKeyedTag;
+ private final Schema joinedSchema;
+
+ ConvertToRow(
+ List<String> sortedTags,
+ Map<Integer, SerializableFunction<Object, Row>> toRows,
+ Schema joinedSchema,
+ Map<Integer, String> tagToKeyedTag) {
this.sortedTags = sortedTags;
this.toRows = toRows;
this.joinedSchema = joinedSchema;
+ this.tagToKeyedTag = tagToKeyedTag;
}
@ProcessElement
public void process(@Element KV<Row, CoGbkResult> kv, OutputReceiver<KV<Row, Row>> o) {
Row key = kv.getKey();
CoGbkResult result = kv.getValue();
- List<Object> fields = Lists.newArrayListWithExpectedSize(sortedTags.size());
- for (TupleTag<?> tag : sortedTags) {
+ List<Object> fields = Lists.newArrayListWithCapacity(sortedTags.size());
+ for (int i = 0; i < sortedTags.size(); ++i) {
+ String tag = sortedTags.get(i);
// TODO: This forces the entire join to materialize in memory. We should create a
// lazy Row interface on top of the iterable returned by CoGbkResult. This will
- // allow the data to be streamed in.
- SerializableFunction<Object, Row> toRow = toRows.get(tag.getId());
+ // allow the data to be streamed in. Tracked in [BEAM-6756].
+ SerializableFunction<Object, Row> toRow = toRows.get(i);
+ String tupleTag = tagToKeyedTag.get(i);
List<Row> joined = Lists.newArrayList();
- for (Object item : result.getAll(tag)) {
+ for (Object item : result.getAll(tupleTag)) {
joined.add(toRow.apply(item));
}
fields.add(joined);
@@ -327,24 +520,145 @@ public class CoGroup {
o.output(KV.of(key, Row.withSchema(joinedSchema).addValues(fields).build()));
}
}
+ }
- private static <T> PCollection<KV<Row, Row>> extractKey(
- PCollection<T> pCollection,
- Schema schema,
- Schema keySchema,
- FieldAccessDescriptor keyFields,
- String tag) {
- return pCollection
+ /** A {@link PTransform} that calculates the cross-product join. */
+ public static class ExpandCrossProduct extends PTransform<PCollectionTuple, PCollection<Row>> {
+ private final JoinArguments joinArgs;
+
+ ExpandCrossProduct(JoinArguments joinArgs) {
+ this.joinArgs = joinArgs;
+ }
+
+ /**
+ * Select the following fields for the specified PCollection with the specified join args.
+ *
+ * <p>Each PCollection in the input must have fields specified for the join key.
+ */
+ public ExpandCrossProduct join(String tag, By clause) {
+ if (joinArgs.allInputsJoinArgs != null) {
+ throw new IllegalStateException("Cannot set both a global and per-tag fields.");
+ }
+ return new ExpandCrossProduct(joinArgs.with(tag, clause));
+ }
+
+ private Schema getOutputSchema(JoinInformation joinInformation) {
+ // Construct the output schema. It contains one field for each input PCollection, of type
+ // ROW. If a field supports outer-join semantics, then that field will be nullable in the
+ // schema.
+ Schema.Builder joinedSchemaBuilder = Schema.builder();
+ for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
+ FieldType fieldType = FieldType.row(entry.getValue());
+ if (joinArgs.getOuterJoinParticipation(entry.getKey())) {
+ fieldType = fieldType.withNullable(true);
+ }
+ joinedSchemaBuilder.addField(entry.getKey(), fieldType);
+ }
+ return joinedSchemaBuilder.build();
+ }
+
+ @Override
+ public PCollection<Row> expand(PCollectionTuple input) {
+ verify(input, joinArgs);
+
+ JoinInformation joinInformation =
+ JoinInformation.from(input, joinArgs::getFieldAccessDescriptor);
+
+ Schema joinedSchema = getOutputSchema(joinInformation);
+
+ return joinInformation
+ .keyedPCollectionTuple
+ .apply("CoGroupByKey", CoGroupByKey.create())
+ .apply("Values", Values.create())
.apply(
- "extractKey" + tag,
+ "ExpandToRow",
ParDo.of(
- new DoFn<T, KV<Row, Row>>() {
- @ProcessElement
- public void process(@Element Row row, OutputReceiver<KV<Row, Row>> o) {
- o.output(KV.of(Select.selectRow(row, keyFields, schema, keySchema), row));
- }
- }))
- .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)));
+ new ExpandToRows(
+ joinInformation.sortedTags,
+ joinInformation.toRows,
+ joinedSchema,
+ joinInformation.tagToKeyedTag)))
+ .setRowSchema(joinedSchema);
+ }
+
+ /** A DoFn that expands the result of a CoGroupByKey into the cross product. */
+ private class ExpandToRows extends DoFn<CoGbkResult, Row> {
+ private final List<String> sortedTags;
+ private final Map<Integer, SerializableFunction<Object, Row>> toRows;
+ private final Schema outputSchema;
+ private final Map<Integer, String> tagToKeyedTag;
+
+ public ExpandToRows(
+ List<String> sortedTags,
+ Map<Integer, SerializableFunction<Object, Row>> toRows,
+ Schema outputSchema,
+ Map<Integer, String> tagToKeyedTag) {
+ this.sortedTags = sortedTags;
+ this.toRows = toRows;
+ this.outputSchema = outputSchema;
+ this.tagToKeyedTag = tagToKeyedTag;
+ }
+
+ @ProcessElement
+ public void process(@Element CoGbkResult gbkResult, OutputReceiver<Row> o) {
+ List<Iterable> allIterables = extractIterables(gbkResult);
+ List<Row> accumulatedRows = Lists.newArrayListWithCapacity(sortedTags.size());
+ crossProduct(0, accumulatedRows, allIterables, o);
+ }
+
+ private List<Iterable> extractIterables(CoGbkResult gbkResult) {
+ List<Iterable> iterables = Lists.newArrayListWithCapacity(sortedTags.size());
+ for (int i = 0; i < sortedTags.size(); ++i) {
+ String tag = sortedTags.get(i);
+ Iterable items = gbkResult.getAll(tagToKeyedTag.get(i));
+ if (!items.iterator().hasNext() && joinArgs.getOuterJoinParticipation(tag)) {
+ // If this tag has outer-join participation, then empty should participate as a
+ // single null.
+ items = () -> NULL_LIST.iterator();
+ }
+ iterables.add(items);
+ }
+ return iterables;
+ }
+
+ private void crossProduct(
+ int tagIndex,
+ List<Row> accumulatedRows,
+ List<Iterable> iterables,
+ OutputReceiver<Row> o) {
+ if (tagIndex >= sortedTags.size()) {
+ return;
+ }
+
+ SerializableFunction<Object, Row> toRow = toRows.get(tagIndex);
+ for (Object item : iterables.get(tagIndex)) {
+ // For every item that joined for the current input, recurse down to calculate the
+ // list of expanded records.
+ Row row = toRow.apply(item);
+ crossProductHelper(tagIndex, accumulatedRows, row, iterables, o);
+ }
+ }
+
+ private void crossProductHelper(
+ int tagIndex,
+ List<Row> accumulatedRows,
+ Row newRow,
+ List<Iterable> iterables,
+ OutputReceiver<Row> o) {
+ boolean atBottom = tagIndex == sortedTags.size() - 1;
+ accumulatedRows.add(newRow);
+ if (atBottom) {
+ // Bottom of recursive call, so output the row we've accumulated.
+ o.output(buildOutputRow(accumulatedRows));
+ } else {
+ crossProduct(tagIndex + 1, accumulatedRows, iterables, o);
+ }
+ accumulatedRows.remove(accumulatedRows.size() - 1);
+ }
+
+ private Row buildOutputRow(List rows) {
+ return Row.withSchema(outputSchema).addValues(rows).build();
+ }
}
}
}
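Note on the expansion logic in ExpandToRows above: the following is a minimal stand-alone sketch of the same recursive cross-product idea, written without any Beam dependency. The class name CrossProductSketch, the method names, and the use of plain strings in place of Row values are assumptions made for illustration only; they are not part of this commit. As in extractIterables, an input flagged for outer-join participation that has no values for a key is treated as a single null; an unflagged empty input yields no output for that key.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical illustration only; not part of the Beam code base. */
class CrossProductSketch {

  /**
   * inputs.get(i) holds the values that joined for input i (for one key);
   * outer.get(i) says whether input i has outer-join participation.
   */
  static List<List<String>> expand(List<List<String>> inputs, List<Boolean> outer) {
    List<List<String>> out = new ArrayList<>();
    if (inputs.isEmpty()) {
      return out;
    }
    List<List<String>> prepared = new ArrayList<>();
    for (int i = 0; i < inputs.size(); i++) {
      List<String> items = inputs.get(i);
      if (items.isEmpty() && outer.get(i)) {
        // An empty input with outer-join participation contributes a single null.
        items = Collections.singletonList(null);
      }
      prepared.add(items);
    }
    recurse(0, new ArrayList<>(), prepared, out);
    return out;
  }

  private static void recurse(
      int index, List<String> accumulated, List<List<String>> inputs, List<List<String>> out) {
    if (index == inputs.size()) {
      // Bottom of the recursion: emit a copy of the accumulated combination.
      out.add(new ArrayList<>(accumulated));
      return;
    }
    for (String item : inputs.get(index)) {
      accumulated.add(item);
      recurse(index + 1, accumulated, inputs, out);
      // Backtrack before trying the next item of this input.
      accumulated.remove(accumulated.size() - 1);
    }
  }
}
```

The number of emitted combinations is the product of the per-input list sizes, so a key with many values in several inputs can produce a very large output.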
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
index 1ed6a28..3496558 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
@@ -146,6 +146,19 @@ public class Create<T> {
}
/**
+ * Returns a new {@code Create.Values} transform that produces an empty {@link PCollection} of
+ * rows.
+ */
+ public static Values<Row> empty(Schema schema) {
+ return new Values<Row>(
+ new ArrayList<>(),
+ Optional.of(
+ SchemaCoder.of(
+ schema, SerializableFunctions.identity(), SerializableFunctions.identity())),
+ Optional.absent());
+ }
+
+ /**
* Returns a new {@code Create.Values} transform that produces an empty {@link PCollection}.
*
* <p>The elements will have a timestamp of negative infinity, see {@link Create#timestamped} for
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
index 27dc405..df97041 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
@@ -167,6 +167,11 @@ public class CoGbkResult {
return unions;
}
+ /** Like {@link #getAll(TupleTag)} but using a String instead of a {@link TupleTag}. */
+ public <V> Iterable<V> getAll(String tag) {
+ return getAll(new TupleTag<>(tag));
+ }
+
/**
* If there is a singleton value for the given tag, returns it. Otherwise, throws an
* IllegalArgumentException.
@@ -178,6 +183,12 @@ public class CoGbkResult {
return innerGetOnly(tag, null, false);
}
+ /** Like {@link #getOnly(TupleTag)} but using a String instead of a TupleTag. */
+ @SuppressWarnings("TypeParameterUnusedInFormals")
+ public <V> V getOnly(String tag) {
+ return getOnly(new TupleTag<>(tag));
+ }
+
/**
* If there is a singleton value for the given tag, returns it. If there is no value for the given
* tag, returns the defaultValue.
@@ -190,6 +201,12 @@ public class CoGbkResult {
return innerGetOnly(tag, defaultValue, true);
}
+ /** Like {@link #getOnly(TupleTag, Object)} but using a String instead of a TupleTag. */
+ @Nullable
+ public <V> V getOnly(String tag, @Nullable V defaultValue) {
+ return getOnly(new TupleTag<>(tag), defaultValue);
+ }
+
/** A {@link Coder} for {@link CoGbkResult}s. */
public static class CoGbkResultCoder extends CustomCoder<CoGbkResult> {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java
index 5ebc1d5..7bb2781 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java
@@ -53,6 +53,16 @@ public class KeyedPCollectionTuple<K> implements PInput {
}
/**
+ * A version of {@link #of(TupleTag, PCollection)} that takes in a string instead of a TupleTag.
+ *
+ * <p>This method is simpler for cases when a typed tuple-tag is not needed to extract a
+ * PCollection, for example when using schema transforms.
+ */
+ public static <K, InputT> KeyedPCollectionTuple<K> of(String tag, PCollection<KV<K, InputT>> pc) {
+ return of(new TupleTag<>(tag), pc);
+ }
+
+ /**
* Returns a new {@code KeyedPCollectionTuple<K>} that is the same as this, appended with the
* given PCollection.
*/
@@ -67,6 +77,16 @@ public class KeyedPCollectionTuple<K> implements PInput {
getPipeline(), newKeyedCollections, schema.getTupleTagList().and(tag), myKeyCoder);
}
+ /**
+ * A version of {@link #and(TupleTag, PCollection)} that takes in a string instead of a TupleTag.
+ *
+ * <p>This method is simpler for cases when a typed tuple-tag is not needed to extract a
+ * PCollection, for example when using schema transforms.
+ */
+ public <V> KeyedPCollectionTuple<K> and(String tag, PCollection<KV<K, V>> pc) {
+ return and(new TupleTag<>(tag), pc);
+ }
+
public boolean isEmpty() {
return keyedCollections.isEmpty();
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
index c40d683..92fe0ee 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
@@ -92,6 +92,73 @@ public class PCollectionTuple implements PInput, POutput {
}
/**
+ * A version of {@link #of(TupleTag, PCollection)} that takes in a String instead of a {@link
+ * TupleTag}.
+ *
+ * <p>This method is simpler for cases when a typed tuple-tag is not needed to extract a
+ * PCollection, for example when using schema transforms.
+ */
+ public static <T> PCollectionTuple of(String tag, PCollection<T> pc) {
+ return of(new TupleTag<>(tag), pc);
+ }
+
+ /**
+ * A version of {@link #of(String, PCollection)} that takes in two PCollections of the same type.
+ */
+ public static <T> PCollectionTuple of(
+ String tag1, PCollection<T> pc1, String tag2, PCollection<T> pc2) {
+ return of(tag1, pc1).and(tag2, pc2);
+ }
+
+ /**
+ * A version of {@link #of(String, PCollection)} that takes in three PCollections of the same
+ * type.
+ */
+ public static <T> PCollectionTuple of(
+ String tag1,
+ PCollection<T> pc1,
+ String tag2,
+ PCollection<T> pc2,
+ String tag3,
+ PCollection<T> pc3) {
+ return of(tag1, pc1, tag2, pc2).and(tag3, pc3);
+ }
+
+ /**
+ * A version of {@link #of(String, PCollection)} that takes in four PCollections of the same type.
+ */
+ public static <T> PCollectionTuple of(
+ String tag1,
+ PCollection<T> pc1,
+ String tag2,
+ PCollection<T> pc2,
+ String tag3,
+ PCollection<T> pc3,
+ String tag4,
+ PCollection<T> pc4) {
+ return of(tag1, pc1, tag2, pc2, tag3, pc3).and(tag4, pc4);
+ }
+
+ /**
+ * A version of {@link #of(String, PCollection)} that takes in five PCollections of the same type.
+ */
+ public static <T> PCollectionTuple of(
+ String tag1,
+ PCollection<T> pc1,
+ String tag2,
+ PCollection<T> pc2,
+ String tag3,
+ PCollection<T> pc3,
+ String tag4,
+ PCollection<T> pc4,
+ String tag5,
+ PCollection<T> pc5) {
+ return of(tag1, pc1, tag2, pc2, tag3, pc3, tag4, pc4).and(tag5, pc5);
+ }
+
+ // To create a PCollectionTuple with more than five inputs, use the and() builder method.
+
+ /**
* Returns a new {@link PCollectionTuple} that has each {@link PCollection} and {@link TupleTag}
* of this {@link PCollectionTuple} plus the given {@link PCollection} associated with the given
* {@link TupleTag}.
@@ -116,6 +183,16 @@ public class PCollectionTuple implements PInput, POutput {
}
/**
+ * A version of {@link #and(TupleTag, PCollection)} that takes in a String instead of a TupleTag.
+ *
+ * <p>This method is simpler for cases when a typed tuple-tag is not needed to extract a
+ * PCollection, for example when using schema transforms.
+ */
+ public <T> PCollectionTuple and(String tag, PCollection<T> pc) {
+ return and(new TupleTag<>(tag), pc);
+ }
+
+ /**
* Returns whether this {@link PCollectionTuple} contains a {@link PCollection} with the given
* tag.
*/
@@ -124,6 +201,14 @@ public class PCollectionTuple implements PInput, POutput {
}
/**
+ * Returns whether this {@link PCollectionTuple} contains a {@link PCollection} with the given
+ * tag.
+ */
+ public <T> boolean has(String tag) {
+ return has(new TupleTag<>(tag));
+ }
+
+ /**
* Returns the {@link PCollection} associated with the given {@link TupleTag} in this {@link
* PCollectionTuple}. Throws {@link IllegalArgumentException} if there is no such {@link
* PCollection}, i.e., {@code !has(tag)}.
@@ -138,6 +223,15 @@ public class PCollectionTuple implements PInput, POutput {
}
/**
+ * Returns the {@link PCollection} associated with the given tag in this {@link PCollectionTuple}.
+ * Throws {@link IllegalArgumentException} if there is no such {@link PCollection}, i.e., {@code
+ * !has(tag)}.
+ */
+ public <T> PCollection<T> get(String tag) {
+ return get(new TupleTag<>(tag));
+ }
+
+ /**
* Returns an immutable Map from {@link TupleTag} to corresponding {@link PCollection}, for all
* the members of this {@link PCollectionTuple}.
*/
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
index ac11462..47958e2 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
@@ -544,8 +544,8 @@ public abstract class Row implements Serializable {
if (schema.getFieldCount() != values.size()) {
throw new IllegalArgumentException(
String.format(
- "Field count in Schema (%s) and values (%s) must match",
- schema.getFieldNames(), values));
+ "Field count in Schema (%s) (%d) and values (%s) (%d) must match",
+ schema.getFieldNames(), schema.getFieldCount(), values, values.size()));
}
for (int i = 0; i < values.size(); ++i) {
Object value = values.get(i);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
index 35bd657..2251c6c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
@@ -23,11 +23,14 @@ import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
import static org.junit.Assert.assertThat;
+import java.util.Arrays;
import java.util.List;
+import java.util.stream.Collectors;
import org.apache.beam.sdk.TestUtils.KvMatcher;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
+import org.apache.beam.sdk.schemas.transforms.CoGroup.By;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
@@ -36,7 +39,6 @@ import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
-import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
import org.hamcrest.BaseMatcher;
@@ -183,10 +185,8 @@ public class CoGroupTest {
.build();
PCollection<KV<Row, Row>> joined =
- PCollectionTuple.of(new TupleTag<>("pc1"), pc1)
- .and(new TupleTag<>("pc2"), pc2)
- .and(new TupleTag<>("pc3"), pc3)
- .apply("CoGroup", CoGroup.byFieldNames("user", "country"));
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+ .apply("CoGroup", CoGroup.join(By.fieldNames("user", "country")));
List<KV<Row, Row>> expected =
ImmutableList.of(
KV.of(key1, key1Joined),
@@ -325,19 +325,13 @@ public class CoGroupTest {
Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build()))
.build();
- TupleTag<Row> pc1Tag = new TupleTag<>("pc1");
- TupleTag<Row> pc2Tag = new TupleTag<>("pc2");
- TupleTag<Row> pc3Tag = new TupleTag<>("pc3");
-
PCollection<KV<Row, Row>> joined =
- PCollectionTuple.of(pc1Tag, pc1)
- .and(pc2Tag, pc2)
- .and(pc3Tag, pc3)
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
.apply(
"CoGroup",
- CoGroup.byFieldNames(pc1Tag, "user", "country")
- .byFieldNames(pc2Tag, "user2", "country2")
- .byFieldNames(pc3Tag, "user3", "country3"));
+ CoGroup.join("pc1", By.fieldNames("user", "country"))
+ .join("pc2", By.fieldNames("user2", "country2"))
+ .join("pc3", By.fieldNames("user3", "country3")));
List<KV<Row, Row>> expected =
ImmutableList.of(
@@ -367,19 +361,14 @@ public class CoGroupTest {
PCollection<Row> pc3 =
pipeline.apply(
"Create3", Create.of(Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build()));
- TupleTag<Row> pc1Tag = new TupleTag<>("pc1");
- TupleTag<Row> pc2Tag = new TupleTag<>("pc2");
- TupleTag<Row> pc3Tag = new TupleTag<>("pc3");
- thrown.expect(IllegalStateException.class);
+ thrown.expect(IllegalArgumentException.class);
PCollection<KV<Row, Row>> joined =
- PCollectionTuple.of(pc1Tag, pc1)
- .and(pc2Tag, pc2)
- .and(pc3Tag, pc3)
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
.apply(
"CoGroup",
- CoGroup.byFieldNames(pc1Tag, "user", "country")
- .byFieldNames(pc2Tag, "user2", "country2"));
+ CoGroup.join("pc1", By.fieldNames("user", "country"))
+ .join("pc2", By.fieldNames("user2", "country2")));
pipeline.run();
}
@@ -399,13 +388,318 @@ public class CoGroupTest {
Create.of(Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build()))
.setRowSchema(CG_SCHEMA_1);
- TupleTag<Row> pc1Tag = new TupleTag<>("pc1");
- TupleTag<Row> pc2Tag = new TupleTag<>("pc2");
thrown.expect(IllegalStateException.class);
PCollection<KV<Row, Row>> joined =
- PCollectionTuple.of(pc1Tag, pc1)
- .and(pc2Tag, pc2)
- .apply("CoGroup", CoGroup.byFieldNames(pc1Tag, "user").byFieldNames(pc2Tag, "count"));
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2)
+ .apply(
+ "CoGroup",
+ CoGroup.join("pc1", By.fieldNames("user")).join("pc2", By.fieldNames("count")));
+ pipeline.run();
+ }
+
+ private List<Row> innerJoin(
+ List<Row> inputs1,
+ List<Row> inputs2,
+ List<Row> inputs3,
+ String[] keys1,
+ String[] keys2,
+ String[] keys3,
+ Schema expectedSchema) {
+ List<Row> joined = Lists.newArrayList();
+ for (Row row1 : inputs1) {
+ for (Row row2 : inputs2) {
+ for (Row row3 : inputs3) {
+ List key1 = Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList());
+ List key2 = Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList());
+ List key3 = Arrays.stream(keys3).map(row3::getValue).collect(Collectors.toList());
+ if (key1.equals(key2) && key2.equals(key3)) {
+ joined.add(Row.withSchema(expectedSchema).addValues(row1, row2, row3).build());
+ }
+ }
+ }
+ }
+ return joined;
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testInnerJoin() {
+ List<Row> pc1Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build());
+ List<Row> pc2Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build());
+ List<Row> pc3Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 18, "us").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 19, "il").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 20, "il").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 21, "fr").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 22, "fr").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build());
+
+ PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+ PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2);
+ PCollection<Row> pc3 = pipeline.apply("Create3", Create.of(pc3Rows)).setRowSchema(CG_SCHEMA_3);
+
+ Schema expectedSchema =
+ Schema.builder()
+ .addRowField("pc1", CG_SCHEMA_1)
+ .addRowField("pc2", CG_SCHEMA_2)
+ .addRowField("pc3", CG_SCHEMA_3)
+ .build();
+
+ PCollection<Row> joined =
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+ .apply(
+ "CoGroup",
+ CoGroup.join("pc1", By.fieldNames("user", "country"))
+ .join("pc2", By.fieldNames("user2", "country2"))
+ .join("pc3", By.fieldNames("user3", "country3"))
+ .crossProductJoin());
+ assertEquals(expectedSchema, joined.getSchema());
+
+ List<Row> expectedJoinedRows =
+ innerJoin(
+ pc1Rows,
+ pc2Rows,
+ pc3Rows,
+ new String[] {"user", "country"},
+ new String[] {"user2", "country2"},
+ new String[] {"user3", "country3"},
+ expectedSchema);
+
+ PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+ pipeline.run();
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testFullOuterJoin() {
+ List<Row> pc1Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user3", 7, "ar").build());
+
+ List<Row> pc2Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "es").build());
+
+ List<Row> pc3Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 18, "us").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 19, "il").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 20, "il").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 21, "fr").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 22, "fr").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user27", 24, "se").build());
+
+ PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+ PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2);
+ PCollection<Row> pc3 = pipeline.apply("Create3", Create.of(pc3Rows)).setRowSchema(CG_SCHEMA_3);
+
+ // Full outer join, so any field might be null.
+ Schema expectedSchema =
+ Schema.builder()
+ .addNullableField("pc1", FieldType.row(CG_SCHEMA_1))
+ .addNullableField("pc2", FieldType.row(CG_SCHEMA_2))
+ .addNullableField("pc3", FieldType.row(CG_SCHEMA_3))
+ .build();
+
+ PCollection<Row> joined =
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+ .apply(
+ "CoGroup",
+ CoGroup.join("pc1", By.fieldNames("user", "country").withOuterJoinParticipation())
+ .join("pc2", By.fieldNames("user2", "country2").withOuterJoinParticipation())
+ .join("pc3", By.fieldNames("user3", "country3").withOuterJoinParticipation())
+ .crossProductJoin());
+ assertEquals(expectedSchema, joined.getSchema());
+
+ List<Row> expectedJoinedRows =
+ innerJoin(
+ pc1Rows,
+ pc2Rows,
+ pc3Rows,
+ new String[] {"user", "country"},
+ new String[] {"user2", "country2"},
+ new String[] {"user3", "country3"},
+ expectedSchema);
+ // Manually add the outer-join rows to the list of expected results.
+ expectedJoinedRows.add(
+ Row.withSchema(expectedSchema)
+ .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 7, "ar").build(), null, null)
+ .build());
+ expectedJoinedRows.add(
+ Row.withSchema(expectedSchema)
+ .addValues(null, Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "es").build(), null)
+ .build());
+ expectedJoinedRows.add(
+ Row.withSchema(expectedSchema)
+ .addValues(
+ null, null, Row.withSchema(CG_SCHEMA_3).addValues("user27", 24, "se").build())
+ .build());
+
+ PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+ pipeline.run();
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testPartialOuterJoin() {
+ List<Row> pc1Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build());
+
+ List<Row> pc2Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+ Row.withSchema(CG_SCHEMA_2).addValues("user3", 7, "ar").build());
+
+ List<Row> pc3Rows =
+ Lists.newArrayList(
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 18, "us").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 19, "il").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user1", 20, "il").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 21, "fr").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 22, "fr").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build(),
+ Row.withSchema(CG_SCHEMA_3).addValues("user3", 25, "ar").build());
+
+ PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+ PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2);
+ PCollection<Row> pc3 = pipeline.apply("Create3", Create.of(pc3Rows)).setRowSchema(CG_SCHEMA_3);
+
+ // Partial outer join. Missing entries in the "pc2" PCollection will be filled in with nulls,
+ // but not others.
+ Schema expectedSchema =
+ Schema.builder()
+ .addField("pc1", FieldType.row(CG_SCHEMA_1))
+ .addNullableField("pc2", FieldType.row(CG_SCHEMA_2))
+ .addField("pc3", FieldType.row(CG_SCHEMA_3))
+ .build();
+
+ PCollection<Row> joined =
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+ .apply(
+ "CoGroup",
+ CoGroup.join("pc1", By.fieldNames("user", "country"))
+ .join("pc2", By.fieldNames("user2", "country2").withOuterJoinParticipation())
+ .join("pc3", By.fieldNames("user3", "country3"))
+ .crossProductJoin());
+ assertEquals(expectedSchema, joined.getSchema());
+
+ List<Row> expectedJoinedRows =
+ innerJoin(
+ pc1Rows,
+ pc2Rows,
+ pc3Rows,
+ new String[] {"user", "country"},
+ new String[] {"user2", "country2"},
+ new String[] {"user3", "country3"},
+ expectedSchema);
+
+ // Manually add the outer-join rows to the list of expected results. Missing results from the
+ // middle (pc2) PCollection are filled in with nulls. Missing events from other PCollections
+ // are not. Events with key ("user2", "ar") show up in pc1 and pc3 but not in pc2, so we expect
+ // the outer join to still produce those rows, with nulls for pc2. Events with key
+ // ("user3", "ar") however show up in pc2 and pc3, but not in pc1; since pc1 is marked for
+ // full participation (no outer join), these events should not be included in the join.
+ expectedJoinedRows.add(
+ Row.withSchema(expectedSchema)
+ .addValues(
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+ null,
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build())
+ .build());
+ expectedJoinedRows.add(
+ Row.withSchema(expectedSchema)
+ .addValues(
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+ null,
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build())
+ .build());
+ expectedJoinedRows.add(
+ Row.withSchema(expectedSchema)
+ .addValues(
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+ null,
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build())
+ .build());
+ expectedJoinedRows.add(
+ Row.withSchema(expectedSchema)
+ .addValues(
+ Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+ null,
+ Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build())
+ .build());
+
+ PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+ pipeline.run();
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testUnmatchedTags() {
+ PCollection<Row> pc1 = pipeline.apply("Create1", Create.empty(CG_SCHEMA_1));
+ PCollection<Row> pc2 = pipeline.apply("Create2", Create.empty(CG_SCHEMA_2));
+
+ thrown.expect(IllegalArgumentException.class);
+
+ PCollectionTuple.of("pc1", pc1, "pc2", pc2)
+ .apply(
+ CoGroup.join("pc1", By.fieldNames("user"))
+ .join("pc3", By.fieldNames("user3"))
+ .crossProductJoin());
pipeline.run();
}