You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by re...@apache.org on 2019/03/01 15:05:57 UTC

[beam] branch master updated: Merge pull request #7353: [BEAM-4461] Support inner and outer style joins in CoGroup.

This is an automated email from the ASF dual-hosted git repository.

reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c41b3c0  Merge pull request #7353: [BEAM-4461] Support inner and outer style joins in CoGroup.
c41b3c0 is described below

commit c41b3c082d924059c393345f4b1e740804d2b877
Author: reuvenlax <re...@google.com>
AuthorDate: Fri Mar 1 07:05:46 2019 -0800

    Merge pull request #7353: [BEAM-4461] Support inner and outer style joins in CoGroup.
---
 .../beam/sdk/schemas/transforms/CoGroup.java       | 612 ++++++++++++++++-----
 .../org/apache/beam/sdk/transforms/Create.java     |  13 +
 .../beam/sdk/transforms/join/CoGbkResult.java      |  17 +
 .../sdk/transforms/join/KeyedPCollectionTuple.java |  20 +
 .../apache/beam/sdk/values/PCollectionTuple.java   |  94 ++++
 .../main/java/org/apache/beam/sdk/values/Row.java  |   4 +-
 .../beam/sdk/schemas/transforms/CoGroupTest.java   | 352 +++++++++++-
 7 files changed, 932 insertions(+), 180 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
index ebcf418..3822821 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
@@ -17,11 +17,14 @@
  */
 package org.apache.beam.sdk.schemas.transforms;
 
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
 import java.util.Collections;
-import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.TreeMap;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -33,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.join.CoGbkResult;
 import org.apache.beam.sdk.transforms.join.CoGroupByKey;
 import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
@@ -57,14 +61,9 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
  * <p>For example, the following demonstrates joining three PCollections on the "user" and "country"
  * fields:
  *
- * <pre>{@code TupleTag<Input1Type> input1Tag = new TupleTag<>("input1");
- * TupleTag<Input2Type> input2Tag = new TupleTag<>("input2");
- * TupleTag<Input3Type> input3Tag = new TupleTag<>("input3");
- * PCollection<KV<Row, Row>> joined = PCollectionTuple
- *     .of(input1Tag, input1)
- *     .and(input2Tag, input2)
- *     .and(input3Tag, input3)
- *   .apply(CoGroup.byFieldNames("user", "country"));
+ * <pre>{@code PCollection<KV<Row, Row>> joined =
+ *   PCollectionTuple.of("input1", input1, "input2", input2, "input3", input3)
+ *     .apply(CoGroup.join(By.fieldNames("user", "country")));
  * }</pre>
  *
  * <p>In the above case, the key schema will contain the two string fields "user" and "country"; in
@@ -107,156 +106,251 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
  * those fields match. In this case, fields must be specified for every input PCollection. For
  * example:
  *
- * <pre>{@code PCollection<KV<Row, Row>> joined = PCollectionTuple
- *     .of(input1Tag, input1)
- *     .and(input2Tag, input2)
+ * <pre>{@code PCollection<KV<Row, Row>> joined
+ *      = PCollectionTuple.of("input1Tag", input1, "input2Tag", input2)
  *   .apply(CoGroup
- *     .byFieldNames(input1Tag, "referringUser"))
- *     .byFieldNames(input2Tag, "user"));
+ *     .join("input1Tag", By.fieldNames("referringUser"))
+ *     .join("input2Tag", By.fieldNames("user")));
  * }</pre>
+ *
+ * <p>Traditional (SQL) joins are cross-product joins. All rows that match the join condition are
+ * combined into individual rows and returned; in fact, any SQL inner join is a subset of the
+ * cross-product of two tables. This transform also supports the same functionality using the {@link
+ * Inner#crossProductJoin()} method.
+ *
+ * <p>For example, consider the SQL join: SELECT * FROM input1 INNER JOIN input2 ON input1.user =
+ * input2.user
+ *
+ * <p>You could express this with:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ *   .apply(CoGroup.join(By.fieldNames("user")).crossProductJoin());
+ * }</pre>
+ *
+ * <p>The schema of the output PCollection contains a nested row for each of input1 and input2.
+ * Like above, you could use the {@link Convert} transform to convert it to the following POJO:
+ *
+ * <pre>{@code
+ * {@literal @}DefaultSchema(JavaFieldSchema.class)
+ * public class JoinedValue {
+ *   public Input1Type input1;
+ *   public Input2Type input2;
+ * }
+ * }</pre>
+ *
+ * <p>The {@link Unnest} transform can then be used to flatten all the subfields into one single
+ * top-level row containing all the fields in both Input1 and Input2; this will often be combined
+ * with a {@link Select} transform to select out the fields of interest, as the key fields will be
+ * identical between input1 and input2.
+ *
+ * <p>This transform also supports outer-join semantics. By default, all input PCollections must
+ * participate fully in the join, providing inner-join semantics. This means that the join will only
+ * produce values for "Bob" if all inputs have values for "Bob"; if even a single input does not
+ * have a value for "Bob", an inner-join will produce no value. However, if you mark that input as
+ * having outer-join participation, then the join will contain values for "Bob" as long as at least
+ * one input has a "Bob" value; null values will be added for inputs that have no "Bob" values. To
+ * continue the SQL example:
+ *
+ * <p>SELECT * FROM input1 LEFT OUTER JOIN input2 ON input1.user = input2.user
+ *
+ * <p>Is equivalent to:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ *   .apply(CoGroup.join("input1", By.fieldNames("user").withOuterJoinParticipation())
+ *                 .join("input2", By.fieldNames("user"))
+ *                 .crossProductJoin());
+ * }</pre>
+ *
+ * <p>SELECT * FROM input1 RIGHT OUTER JOIN input2 ON input1.user = input2.user
+ *
+ * <p>Is equivalent to:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ *   .apply(CoGroup.join("input1", By.fieldNames("user"))
+ *                 .join("input2", By.fieldNames("user").withOuterJoinParticipation())
+ *                 .crossProductJoin());
+ * }</pre>
+ *
+ * <p>and SELECT * FROM input1 FULL OUTER JOIN input2 ON input1.user = input2.user
+ *
+ * <p>Is equivalent to:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2)
+ *   .apply(CoGroup.join("input1", By.fieldNames("user").withOuterJoinParticipation())
+ *                 .join("input2", By.fieldNames("user").withOuterJoinParticipation())
+ *                 .crossProductJoin());
+ * }</pre>
+ *
+ * <p>While the above examples use two inputs to mimic SQL's left and right join semantics, the
+ * {@link CoGroup} transform supports any number of inputs, and outer-join participation can be
+ * specified on any subset of them.
+ *
+ * <p>Do note that cross-product joins, while simpler to program, can cause performance problems.
  */
 public class CoGroup {
-  /**
-   * Join by the following field names.
-   *
-   * <p>The same field names are used in all input PCollections.
-   */
-  public static Inner byFieldNames(String... fieldNames) {
-    return byFieldAccessDescriptor(FieldAccessDescriptor.withFieldNames(fieldNames));
-  }
+  private static final List NULL_LIST;
 
-  /**
-   * Join by the following field ids.
-   *
-   * <p>The same field ids are used in all input PCollections.
-   */
-  public static Inner byFieldIds(Integer... fieldIds) {
-    return byFieldAccessDescriptor(FieldAccessDescriptor.withFieldIds(fieldIds));
+  static {
+    NULL_LIST = Lists.newArrayList();
+    NULL_LIST.add(null);
   }
 
   /**
-   * Join by the following {@link FieldAccessDescriptor}.
-   *
-   * <p>The same access descriptor is used in all input PCollections.
+   * Defines the set of fields to extract for the join key, as well as other per-input join options.
    */
-  public static Inner byFieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) {
-    return new Inner(fieldAccessDescriptor);
-  }
+  @AutoValue
+  public abstract static class By implements Serializable {
+    abstract FieldAccessDescriptor getFieldAccessDescriptor();
 
-  /**
-   * Select the following field names for the specified PCollection.
-   *
-   * <p>Each PCollection in the input must have fields specified for the join key.
-   */
-  public static Inner byFieldNames(TupleTag<?> tag, String... fieldNames) {
-    return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldNames(fieldNames));
-  }
+    abstract boolean getOuterJoinParticipation();
 
-  /**
-   * Select the following field ids for the specified PCollection.
-   *
-   * <p>Each PCollection in the input must have fields specified for the join key.
-   */
-  public static Inner byFieldIds(TupleTag<?> tag, Integer... fieldIds) {
-    return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldIds(fieldIds));
-  }
+    abstract Builder toBuilder();
 
-  /**
-   * Select the following fields for the specified PCollection using {@link FieldAccessDescriptor}.
-   *
-   * <p>Each PCollection in the input must have fields specified for the join key.
-   */
-  public static Inner byFieldAccessDescriptor(
-      TupleTag<?> tag, FieldAccessDescriptor fieldAccessDescriptor) {
-    return new Inner().byFieldAccessDescriptor(tag, fieldAccessDescriptor);
-  }
+    @AutoValue.Builder
+    abstract static class Builder {
+      abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor);
 
-  /** The implementing PTransform. */
-  public static class Inner extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> {
-    @Nullable private final FieldAccessDescriptor allInputsFieldAccessDescriptor;
-    private final Map<TupleTag<?>, FieldAccessDescriptor> fieldAccessDescriptorMap;
+      abstract Builder setOuterJoinParticipation(boolean outerJoinParticipation);
 
-    private Inner() {
-      this(Collections.emptyMap());
+      abstract By build();
     }
 
-    private Inner(Map<TupleTag<?>, FieldAccessDescriptor> fieldAccessDescriptorMap) {
-      this.allInputsFieldAccessDescriptor = null;
-      this.fieldAccessDescriptorMap = fieldAccessDescriptorMap;
+    /** Join by the following field names. */
+    public static By fieldNames(String... fieldNames) {
+      return fieldAccessDescriptor(FieldAccessDescriptor.withFieldNames(fieldNames));
     }
 
-    private Inner(FieldAccessDescriptor allInputsFieldAccessDescriptor) {
-      this.allInputsFieldAccessDescriptor = allInputsFieldAccessDescriptor;
-      this.fieldAccessDescriptorMap = Collections.emptyMap();
+    /** Join by the following field ids. */
+    public static By fieldIds(Integer... fieldIds) {
+      return fieldAccessDescriptor(FieldAccessDescriptor.withFieldIds(fieldIds));
     }
 
-    /**
-     * Join by the following field names.
-     *
-     * <p>The same field names are used in all input PCollections.
-     */
-    public Inner byFieldNames(TupleTag<?> tag, String... fieldNames) {
-      return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldNames(fieldNames));
+    /** Join by the following field access descriptor. */
+    public static By fieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) {
+      return new AutoValue_CoGroup_By.Builder()
+          .setFieldAccessDescriptor(fieldAccessDescriptor)
+          .setOuterJoinParticipation(false)
+          .build();
     }
 
     /**
-     * Select the following field ids for the specified PCollection.
+     * Means that this input will participate in a join even when it has no values for a key,
+     * similar to SQL outer-join semantics. Missing entries will be replaced by nulls.
      *
-     * <p>Each PCollection in the input must have fields specified for the join key.
+     * <p>This only affects the results of {@link Inner#crossProductJoin()}.
      */
-    public Inner byFieldIds(TupleTag<?> tag, Integer... fieldIds) {
-      return byFieldAccessDescriptor(tag, FieldAccessDescriptor.withFieldIds(fieldIds));
+    public By withOuterJoinParticipation() {
+      return toBuilder().setOuterJoinParticipation(true).build();
     }
+  }
 
-    /**
-     * Select the following fields for the specified PCollection using {@link
-     * FieldAccessDescriptor}.
-     *
-     * <p>Each PCollection in the input must have fields specified for the join key.
-     */
-    public Inner byFieldAccessDescriptor(
-        TupleTag<?> tag, FieldAccessDescriptor fieldAccessDescriptor) {
-      if (allInputsFieldAccessDescriptor != null) {
-        throw new IllegalStateException("Cannot set both a global and per-tag fields.");
-      }
-      return new Inner(
-          new ImmutableMap.Builder<TupleTag<?>, FieldAccessDescriptor>()
-              .putAll(fieldAccessDescriptorMap)
-              .put(tag, fieldAccessDescriptor)
-              .build());
+  private static class JoinArguments implements Serializable {
+    @Nullable private final By allInputsJoinArgs;
+    private final Map<String, By> joinArgsMap;
+
+    JoinArguments(@Nullable By allInputsJoinArgs) {
+      this.allInputsJoinArgs = allInputsJoinArgs;
+      this.joinArgsMap = Collections.emptyMap();
+    }
+
+    JoinArguments(Map<String, By> joinArgsMap) {
+      this.allInputsJoinArgs = null;
+      this.joinArgsMap = joinArgsMap;
+    }
+
+    JoinArguments with(String tag, By clause) {
+      return new JoinArguments(
+          new ImmutableMap.Builder<String, By>().putAll(joinArgsMap).put(tag, clause).build());
     }
 
     @Nullable
-    private FieldAccessDescriptor getFieldAccessDescriptor(TupleTag<?> tag) {
-      return (allInputsFieldAccessDescriptor != null)
-          ? allInputsFieldAccessDescriptor
-          : fieldAccessDescriptorMap.get(tag);
+    private FieldAccessDescriptor getFieldAccessDescriptor(String tag) {
+      return (allInputsJoinArgs != null)
+          ? allInputsJoinArgs.getFieldAccessDescriptor()
+          : joinArgsMap.get(tag).getFieldAccessDescriptor();
     }
 
-    @Override
-    public PCollection<KV<Row, Row>> expand(PCollectionTuple input) {
+    private boolean getOuterJoinParticipation(String tag) {
+      return (allInputsJoinArgs != null)
+          ? allInputsJoinArgs.getOuterJoinParticipation()
+          : joinArgsMap.get(tag).getOuterJoinParticipation();
+    }
+  }
+
+  /**
+   * Join all input PCollections using the same args.
+   *
+   * <p>The same fields and other options are used in all input PCollections.
+   */
+  public static Inner join(By clause) {
+    return new Inner(new JoinArguments(clause));
+  }
+
+  /**
+   * Specify the following join arguments (including fields to join by) for the specified
+   * PCollection.
+   *
+   * <p>Each PCollection in the input must have args specified for the join key.
+   */
+  public static Inner join(String tag, By clause) {
+    return new Inner(new JoinArguments(ImmutableMap.of(tag, clause)));
+  }
+
+  // Contains summary information needed for implementing the join.
+  private static class JoinInformation {
+    private final KeyedPCollectionTuple<Row> keyedPCollectionTuple;
+    private final Schema keySchema;
+    private final Map<String, Schema> componentSchemas;
+    // Maps from index in sortedTags to the toRow function.
+    private final Map<Integer, SerializableFunction<Object, Row>> toRows;
+    private final List<String> sortedTags;
+    private final Map<Integer, String> tagToKeyedTag;
+
+    private JoinInformation(
+        KeyedPCollectionTuple<Row> keyedPCollectionTuple,
+        Schema keySchema,
+        Map<String, Schema> componentSchemas,
+        Map<Integer, SerializableFunction<Object, Row>> toRows,
+        List<String> sortedTags,
+        Map<Integer, String> tagToKeyedTag) {
+      this.keyedPCollectionTuple = keyedPCollectionTuple;
+      this.keySchema = keySchema;
+      this.componentSchemas = componentSchemas;
+      this.toRows = toRows;
+      this.sortedTags = sortedTags;
+      this.tagToKeyedTag = tagToKeyedTag;
+    }
+
+    private static JoinInformation from(
+        PCollectionTuple input, Function<String, FieldAccessDescriptor> getFieldAccessDescriptor) {
       KeyedPCollectionTuple<Row> keyedPCollectionTuple =
           KeyedPCollectionTuple.empty(input.getPipeline());
-      List<TupleTag<Row>> sortedTags =
+
+      List<String> sortedTags =
           input.getAll().keySet().stream()
-              .sorted(Comparator.comparing(TupleTag::getId))
-              .map(t -> new TupleTag<Row>(t.getId() + "_ROW"))
+              .map(TupleTag::getId)
+              .sorted()
               .collect(Collectors.toList());
 
       // Keep this in a TreeMap so that it's sorted. This way we get a deterministic output
       // schema.
       TreeMap<String, Schema> componentSchemas = Maps.newTreeMap();
-      Map<String, SerializableFunction<Object, Row>> toRows = Maps.newHashMap();
+      Map<Integer, SerializableFunction<Object, Row>> toRows = Maps.newHashMap();
 
+      Map<Integer, String> tagToKeyedTag = Maps.newHashMap();
       Schema keySchema = null;
       for (Map.Entry<TupleTag<?>, PCollection<?>> entry : input.getAll().entrySet()) {
-        TupleTag<?> tag = entry.getKey();
+        String tag = entry.getKey().getId();
+        int tagIndex = sortedTags.indexOf(tag);
         PCollection<?> pc = entry.getValue();
         Schema schema = pc.getSchema();
-        componentSchemas.put(tag.getId(), schema);
-        TupleTag<Row> rowTag = new TupleTag<>(tag.getId() + "_ROW");
-        toRows.put(rowTag.getId(), (SerializableFunction<Object, Row>) pc.getToRowFunction());
-        FieldAccessDescriptor fieldAccessDescriptor = getFieldAccessDescriptor(tag);
+        componentSchemas.put(tag, schema);
+        toRows.put(tagIndex, (SerializableFunction<Object, Row>) pc.getToRowFunction());
+        FieldAccessDescriptor fieldAccessDescriptor = getFieldAccessDescriptor.apply(tag);
         if (fieldAccessDescriptor == null) {
           throw new IllegalStateException("No fields were set for input " + tag);
         }
@@ -275,51 +369,150 @@ public class CoGroup {
           }
         }
 
+        // Create a unique tag under which this input is added to the KeyedPCollectionTuple.
+        TupleTag randomTag = new TupleTag<>();
+        String keyedTag = tag + "_" + randomTag;
+        tagToKeyedTag.put(tagIndex, keyedTag);
         PCollection<KV<Row, Row>> keyedPCollection =
-            extractKey(pc, schema, keySchema, resolved, tag.getId());
-        keyedPCollectionTuple = keyedPCollectionTuple.and(rowTag, keyedPCollection);
+            extractKey(pc, schema, keySchema, resolved, tag);
+        keyedPCollectionTuple = keyedPCollectionTuple.and(keyedTag, keyedPCollection);
       }
+      return new JoinInformation(
+          keyedPCollectionTuple, keySchema, componentSchemas, toRows, sortedTags, tagToKeyedTag);
+    }
 
+    private static <T> PCollection<KV<Row, Row>> extractKey(
+        PCollection<T> pCollection,
+        Schema schema,
+        Schema keySchema,
+        FieldAccessDescriptor keyFields,
+        String tag) {
+      return pCollection
+          .apply(
+              "extractKey" + tag,
+              ParDo.of(
+                  new DoFn<T, KV<Row, Row>>() {
+                    @ProcessElement
+                    public void process(@Element Row row, OutputReceiver<KV<Row, Row>> o) {
+                      o.output(KV.of(Select.selectRow(row, keyFields, schema, keySchema), row));
+                    }
+                  }))
+          .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)));
+    }
+  }
+
+  static void verify(PCollectionTuple input, JoinArguments joinArgs) {
+    if (joinArgs.allInputsJoinArgs == null) {
+      // If explicit join tags were specified, then they must match the input tuple.
+      Set<String> inputTags =
+          input.getAll().keySet().stream().map(TupleTag::getId).collect(Collectors.toSet());
+      Set<String> joinTags = joinArgs.joinArgsMap.keySet();
+      if (!inputTags.equals(joinTags)) {
+        throw new IllegalArgumentException(
+            "The input PCollectionTuple has tags: "
+                + inputTags
+                + " and the join was specified for tags "
+                + joinTags
+                + ". These do not match.");
+      }
+    }
+  }
+
+  /** The implementing PTransform. */
+  public static class Inner extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> {
+    private final JoinArguments joinArgs;
+
+    private Inner() {
+      this(new JoinArguments(Collections.emptyMap()));
+    }
+
+    private Inner(JoinArguments joinArgs) {
+      this.joinArgs = joinArgs;
+    }
+
+    /**
+     * Select the following fields for the specified PCollection with the specified join args.
+     *
+     * <p>Each PCollection in the input must have fields specified for the join key.
+     */
+    public Inner join(String tag, By clause) {
+      if (joinArgs.allInputsJoinArgs != null) {
+        throw new IllegalStateException("Cannot set both a global and per-tag fields.");
+      }
+      return new Inner(joinArgs.with(tag, clause));
+    }
+
+    /** Expand the join into individual rows, similar to SQL joins. */
+    public ExpandCrossProduct crossProductJoin() {
+      return new ExpandCrossProduct(joinArgs);
+    }
+
+    private Schema getOutputSchema(JoinInformation joinInformation) {
       // Construct the output schema. It contains one field for each input PCollection, of type
       // ARRAY[ROW].
       Schema.Builder joinedSchemaBuilder = Schema.builder();
-      for (Map.Entry<String, Schema> entry : componentSchemas.entrySet()) {
+      for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
         joinedSchemaBuilder.addArrayField(entry.getKey(), FieldType.row(entry.getValue()));
       }
-      Schema joinedSchema = joinedSchemaBuilder.build();
+      return joinedSchemaBuilder.build();
+    }
+
+    @Override
+    public PCollection<KV<Row, Row>> expand(PCollectionTuple input) {
+      verify(input, joinArgs);
 
-      return keyedPCollectionTuple
+      JoinInformation joinInformation =
+          JoinInformation.from(input, joinArgs::getFieldAccessDescriptor);
+
+      Schema joinedSchema = getOutputSchema(joinInformation);
+
+      return joinInformation
+          .keyedPCollectionTuple
           .apply("CoGroupByKey", CoGroupByKey.create())
-          .apply("ConvertToRow", ParDo.of(new ConvertToRow(sortedTags, toRows, joinedSchema)))
-          .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(joinedSchema)));
+          .apply(
+              "ConvertToRow",
+              ParDo.of(
+                  new ConvertToRow(
+                      joinInformation.sortedTags,
+                      joinInformation.toRows,
+                      joinedSchema,
+                      joinInformation.tagToKeyedTag)))
+          .setCoder(
+              KvCoder.of(SchemaCoder.of(joinInformation.keySchema), SchemaCoder.of(joinedSchema)));
     }
 
+    // Used by the unexpanded join to create the output rows.
     private static class ConvertToRow extends DoFn<KV<Row, CoGbkResult>, KV<Row, Row>> {
-      List<TupleTag<Row>> sortedTags;
-      Map<String, SerializableFunction<Object, Row>> toRows = Maps.newHashMap();
-      Schema joinedSchema;
-
-      public ConvertToRow(
-          List<TupleTag<Row>> sortedTags,
-          Map<String, SerializableFunction<Object, Row>> toRows,
-          Schema joinedSchema) {
+      private final List<String> sortedTags;
+      private final Map<Integer, SerializableFunction<Object, Row>> toRows;
+      private final Map<Integer, String> tagToKeyedTag;
+      private final Schema joinedSchema;
+
+      ConvertToRow(
+          List<String> sortedTags,
+          Map<Integer, SerializableFunction<Object, Row>> toRows,
+          Schema joinedSchema,
+          Map<Integer, String> tagToKeyedTag) {
         this.sortedTags = sortedTags;
         this.toRows = toRows;
         this.joinedSchema = joinedSchema;
+        this.tagToKeyedTag = tagToKeyedTag;
       }
 
       @ProcessElement
       public void process(@Element KV<Row, CoGbkResult> kv, OutputReceiver<KV<Row, Row>> o) {
         Row key = kv.getKey();
         CoGbkResult result = kv.getValue();
-        List<Object> fields = Lists.newArrayListWithExpectedSize(sortedTags.size());
-        for (TupleTag<?> tag : sortedTags) {
+        List<Object> fields = Lists.newArrayListWithCapacity(sortedTags.size());
+        for (int i = 0; i < sortedTags.size(); ++i) {
+          String tag = sortedTags.get(i);
           // TODO: This forces the entire join to materialize in memory. We should create a
           // lazy Row interface on top of the iterable returned by CoGbkResult. This will
-          // allow the data to be streamed in.
-          SerializableFunction<Object, Row> toRow = toRows.get(tag.getId());
+          // allow the data to be streamed in. Tracked in [BEAM-6756].
+          SerializableFunction<Object, Row> toRow = toRows.get(i);
+          String tupleTag = tagToKeyedTag.get(i);
           List<Row> joined = Lists.newArrayList();
-          for (Object item : result.getAll(tag)) {
+          for (Object item : result.getAll(tupleTag)) {
             joined.add(toRow.apply(item));
           }
           fields.add(joined);
@@ -327,24 +520,145 @@ public class CoGroup {
         o.output(KV.of(key, Row.withSchema(joinedSchema).addValues(fields).build()));
       }
     }
+  }
 
-    private static <T> PCollection<KV<Row, Row>> extractKey(
-        PCollection<T> pCollection,
-        Schema schema,
-        Schema keySchema,
-        FieldAccessDescriptor keyFields,
-        String tag) {
-      return pCollection
+  /** A {@link PTransform} that calculates the cross-product join. */
+  public static class ExpandCrossProduct extends PTransform<PCollectionTuple, PCollection<Row>> {
+    private final JoinArguments joinArgs;
+
+    ExpandCrossProduct(JoinArguments joinArgs) {
+      this.joinArgs = joinArgs;
+    }
+
+    /**
+     * Select the following fields for the specified PCollection with the specified join args.
+     *
+     * <p>Each PCollection in the input must have fields specified for the join key.
+     */
+    public ExpandCrossProduct join(String tag, By clause) {
+      if (joinArgs.allInputsJoinArgs != null) {
+        throw new IllegalStateException("Cannot set both a global and per-tag fields.");
+      }
+      return new ExpandCrossProduct(joinArgs.with(tag, clause));
+    }
+
+    private Schema getOutputSchema(JoinInformation joinInformation) {
+      // Construct the output schema. It contains one field for each input PCollection, of type
+      // ROW. If a field supports outer-join semantics, then that field will be nullable in the
+      // schema.
+      Schema.Builder joinedSchemaBuilder = Schema.builder();
+      for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
+        FieldType fieldType = FieldType.row(entry.getValue());
+        if (joinArgs.getOuterJoinParticipation(entry.getKey())) {
+          fieldType = fieldType.withNullable(true);
+        }
+        joinedSchemaBuilder.addField(entry.getKey(), fieldType);
+      }
+      return joinedSchemaBuilder.build();
+    }
+
+    @Override
+    public PCollection<Row> expand(PCollectionTuple input) {
+      verify(input, joinArgs);
+
+      JoinInformation joinInformation =
+          JoinInformation.from(input, joinArgs::getFieldAccessDescriptor);
+
+      Schema joinedSchema = getOutputSchema(joinInformation);
+
+      return joinInformation
+          .keyedPCollectionTuple
+          .apply("CoGroupByKey", CoGroupByKey.create())
+          .apply("Values", Values.create())
           .apply(
-              "extractKey" + tag,
+              "ExpandToRow",
               ParDo.of(
-                  new DoFn<T, KV<Row, Row>>() {
-                    @ProcessElement
-                    public void process(@Element Row row, OutputReceiver<KV<Row, Row>> o) {
-                      o.output(KV.of(Select.selectRow(row, keyFields, schema, keySchema), row));
-                    }
-                  }))
-          .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)));
+                  new ExpandToRows(
+                      joinInformation.sortedTags,
+                      joinInformation.toRows,
+                      joinedSchema,
+                      joinInformation.tagToKeyedTag)))
+          .setRowSchema(joinedSchema);
+    }
+
+    /** A DoFn that expands the result of a CoGroupByKey into the cross product. */
+    private class ExpandToRows extends DoFn<CoGbkResult, Row> {
+      private final List<String> sortedTags;
+      private final Map<Integer, SerializableFunction<Object, Row>> toRows;
+      private final Schema outputSchema;
+      private final Map<Integer, String> tagToKeyedTag;
+
+      public ExpandToRows(
+          List<String> sortedTags,
+          Map<Integer, SerializableFunction<Object, Row>> toRows,
+          Schema outputSchema,
+          Map<Integer, String> tagToKeyedTag) {
+        this.sortedTags = sortedTags;
+        this.toRows = toRows;
+        this.outputSchema = outputSchema;
+        this.tagToKeyedTag = tagToKeyedTag;
+      }
+
+      @ProcessElement
+      public void process(@Element CoGbkResult gbkResult, OutputReceiver<Row> o) {
+        List<Iterable> allIterables = extractIterables(gbkResult);
+        List<Row> accumulatedRows = Lists.newArrayListWithCapacity(sortedTags.size());
+        crossProduct(0, accumulatedRows, allIterables, o);
+      }
+
+      private List<Iterable> extractIterables(CoGbkResult gbkResult) {
+        List<Iterable> iterables = Lists.newArrayListWithCapacity(sortedTags.size());
+        for (int i = 0; i < sortedTags.size(); ++i) {
+          String tag = sortedTags.get(i);
+          Iterable items = gbkResult.getAll(tagToKeyedTag.get(i));
+          if (!items.iterator().hasNext() && joinArgs.getOuterJoinParticipation(tag)) {
+            // If this tag has outer-join participation, then empty should participate as a
+            // single null.
+            items = () -> NULL_LIST.iterator();
+          }
+          iterables.add(items);
+        }
+        return iterables;
+      }
+
+      private void crossProduct(
+          int tagIndex,
+          List<Row> accumulatedRows,
+          List<Iterable> iterables,
+          OutputReceiver<Row> o) {
+        if (tagIndex >= sortedTags.size()) {
+          return;
+        }
+
+        SerializableFunction<Object, Row> toRow = toRows.get(tagIndex);
+        for (Object item : iterables.get(tagIndex)) {
+          // For every item that joined for the current input, recurse down to calculate the
+          // list of expanded records.
+          Row row = toRow.apply(item);
+          crossProductHelper(tagIndex, accumulatedRows, row, iterables, o);
+        }
+      }
+
+      private void crossProductHelper(
+          int tagIndex,
+          List<Row> accumulatedRows,
+          Row newRow,
+          List<Iterable> iterables,
+          OutputReceiver<Row> o) {
+        boolean atBottom = tagIndex == sortedTags.size() - 1;
+        accumulatedRows.add(newRow);
+        if (atBottom) {
+          // Bottom of recursive call, so output the row we've accumulated.
+          o.output(buildOutputRow(accumulatedRows));
+        } else {
+          crossProduct(tagIndex + 1, accumulatedRows, iterables, o);
+        }
+        accumulatedRows.remove(accumulatedRows.size() - 1);
+      }
+
+      private Row buildOutputRow(List rows) {
+        return Row.withSchema(outputSchema).addValues(rows).build();
+      }
     }
   }
 }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
index 1ed6a28..3496558 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
@@ -146,6 +146,19 @@ public class Create<T> {
   }
 
   /**
+   * Returns a new {@code Create.Values} transform that produces an empty {@link PCollection} of
+   * {@link Row}s, using a {@link SchemaCoder} built from the given {@link Schema}.
+   */
+  public static Values<Row> empty(Schema schema) {
+    return new Values<Row>(
+        new ArrayList<Row>(),
+        Optional.of(
+            SchemaCoder.of(
+                schema, SerializableFunctions.<Row>identity(), SerializableFunctions.identity())),
+        Optional.absent());
+  }
+
+  /**
    * Returns a new {@code Create.Values} transform that produces an empty {@link PCollection}.
    *
    * <p>The elements will have a timestamp of negative infinity, see {@link Create#timestamped} for
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
index 27dc405..df97041 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
@@ -167,6 +167,11 @@ public class CoGbkResult {
     return unions;
   }
 
+  /** Like {@link #getAll(TupleTag)}, but identifies the tag by its String id. */
+  public <V> Iterable<V> getAll(String tag) {
+    return this.getAll(new TupleTag<V>(tag));
+  }
+
   /**
    * If there is a singleton value for the given tag, returns it. Otherwise, throws an
    * IllegalArgumentException.
@@ -178,6 +183,12 @@ public class CoGbkResult {
     return innerGetOnly(tag, null, false);
   }
 
+  /** Like {@link #getOnly(TupleTag)}, but identifies the tag by its String id. */
+  @SuppressWarnings("TypeParameterUnusedInFormals")
+  public <V> V getOnly(String tag) {
+    return this.getOnly(new TupleTag<V>(tag));
+  }
+
   /**
    * If there is a singleton value for the given tag, returns it. If there is no value for the given
    * tag, returns the defaultValue.
@@ -190,6 +201,12 @@ public class CoGbkResult {
     return innerGetOnly(tag, defaultValue, true);
   }
 
+  /** Like {@link #getOnly(TupleTag, Object)}, but identifies the tag by its String id. */
+  @Nullable
+  public <V> V getOnly(String tag, @Nullable V defaultValue) {
+    return this.getOnly(new TupleTag<V>(tag), defaultValue);
+  }
+
   /** A {@link Coder} for {@link CoGbkResult}s. */
   public static class CoGbkResultCoder extends CustomCoder<CoGbkResult> {
 
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java
index 5ebc1d5..7bb2781 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java
@@ -53,6 +53,16 @@ public class KeyedPCollectionTuple<K> implements PInput {
   }
 
   /**
+   * Like {@link #of(TupleTag, PCollection)}, but identifies the input by a String tag id.
+   *
+   * <p>Convenient when a typed tuple-tag is not needed to extract the PCollection later, for
+   * example when using schema transforms.
+   */
+  public static <K, InputT> KeyedPCollectionTuple<K> of(String tag, PCollection<KV<K, InputT>> pc) {
+    return KeyedPCollectionTuple.of(new TupleTag<InputT>(tag), pc);
+  }
+
+  /**
    * Returns a new {@code KeyedPCollectionTuple<K>} that is the same as this, appended with the
    * given PCollection.
    */
@@ -67,6 +77,16 @@ public class KeyedPCollectionTuple<K> implements PInput {
         getPipeline(), newKeyedCollections, schema.getTupleTagList().and(tag), myKeyCoder);
   }
 
+  /**
+   * A version of {@link #and(TupleTag, PCollection)} that takes in a string instead of a TupleTag.
+   *
+   * <p>This method is simpler for cases when a typed tuple-tag is not needed to extract a
+   * PCollection, for example when using schema transforms.
+   */
+  public <V> KeyedPCollectionTuple<K> and(String tag, PCollection<KV<K, V>> pc) {
+    return and(new TupleTag<>(tag), pc);
+  }
+
   public boolean isEmpty() {
     return keyedCollections.isEmpty();
   }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
index c40d683..92fe0ee 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
@@ -92,6 +92,73 @@ public class PCollectionTuple implements PInput, POutput {
   }
 
   /**
+   * Like {@link #of(TupleTag, PCollection)}, but identifies the PCollection by a String tag id
+   * rather than a {@link TupleTag}.
+   *
+   * <p>Handy when no typed tuple-tag is needed to look the PCollection up again, e.g. when using
+   * schema transforms.
+   */
+  public static <T> PCollectionTuple of(String tag, PCollection<T> pc) {
+    return PCollectionTuple.of(new TupleTag<T>(tag), pc);
+  }
+
+  /**
+   * Like {@link #of(String, PCollection)}, but takes two PCollections of the same type.
+   */
+  public static <T> PCollectionTuple of(
+      String tag1, PCollection<T> pc1, String tag2, PCollection<T> pc2) {
+    return PCollectionTuple.of(tag1, pc1).and(tag2, pc2);
+  }
+
+  /**
+   * Like {@link #of(String, PCollection)}, but takes three PCollections of the same type. The
+   * result is the same as chaining {@link #and(String, PCollection)} onto a one-input tuple.
+   */
+  public static <T> PCollectionTuple of(
+      String tag1,
+      PCollection<T> pc1,
+      String tag2,
+      PCollection<T> pc2,
+      String tag3,
+      PCollection<T> pc3) {
+    return of(tag1, pc1).and(tag2, pc2).and(tag3, pc3);
+  }
+
+  /**
+   * Like {@link #of(String, PCollection)}, but takes four PCollections of the same type.
+   */
+  public static <T> PCollectionTuple of(
+      String tag1,
+      PCollection<T> pc1,
+      String tag2,
+      PCollection<T> pc2,
+      String tag3,
+      PCollection<T> pc3,
+      String tag4,
+      PCollection<T> pc4) {
+    return of(tag1, pc1).and(tag2, pc2).and(tag3, pc3).and(tag4, pc4);
+  }
+
+  /**
+   * Like {@link #of(String, PCollection)}, but takes five PCollections of the same type.
+   */
+  public static <T> PCollectionTuple of(
+      String tag1,
+      PCollection<T> pc1,
+      String tag2,
+      PCollection<T> pc2,
+      String tag3,
+      PCollection<T> pc3,
+      String tag4,
+      PCollection<T> pc4,
+      String tag5,
+      PCollection<T> pc5) {
+    return of(tag1, pc1).and(tag2, pc2).and(tag3, pc3).and(tag4, pc4).and(tag5, pc5);
+  }
+
+  // To create a PCollectionTuple with more than five inputs, use the and() builder method.
+
+  /**
    * Returns a new {@link PCollectionTuple} that has each {@link PCollection} and {@link TupleTag}
    * of this {@link PCollectionTuple} plus the given {@link PCollection} associated with the given
    * {@link TupleTag}.
@@ -116,6 +183,16 @@ public class PCollectionTuple implements PInput, POutput {
   }
 
   /**
+   * Like {@link #and(TupleTag, PCollection)}, but identifies the PCollection by a String tag id.
+   *
+   * <p>Handy when no typed tuple-tag is needed to look the PCollection up again, e.g. when using
+   * schema transforms.
+   */
+  public <T> PCollectionTuple and(String tag, PCollection<T> pc) {
+    return this.and(new TupleTag<T>(tag), pc);
+  }
+
+  /**
    * Returns whether this {@link PCollectionTuple} contains a {@link PCollection} with the given
    * tag.
    */
@@ -124,6 +201,14 @@ public class PCollectionTuple implements PInput, POutput {
   }
 
   /**
+   * Returns whether this {@link PCollectionTuple} contains a {@link PCollection} whose tag has the
+   * given String id.
+   */
+  public boolean has(String tag) {
+    return has(new TupleTag<>(tag));
+  }
+
+  /**
    * Returns the {@link PCollection} associated with the given {@link TupleTag} in this {@link
    * PCollectionTuple}. Throws {@link IllegalArgumentException} if there is no such {@link
    * PCollection}, i.e., {@code !has(tag)}.
@@ -138,6 +223,15 @@ public class PCollectionTuple implements PInput, POutput {
   }
 
   /**
+   * Returns the {@link PCollection} in this {@link PCollectionTuple} whose tag has the given
+   * String id. Throws {@link IllegalArgumentException} if there is no such {@link PCollection},
+   * i.e., {@code !has(tag)}.
+   */
+  public <T> PCollection<T> get(String tag) {
+    return this.get(new TupleTag<T>(tag));
+  }
+
+  /**
    * Returns an immutable Map from {@link TupleTag} to corresponding {@link PCollection}, for all
    * the members of this {@link PCollectionTuple}.
    */
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
index ac11462..47958e2 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
@@ -544,8 +544,8 @@ public abstract class Row implements Serializable {
       if (schema.getFieldCount() != values.size()) {
         throw new IllegalArgumentException(
             String.format(
-                "Field count in Schema (%s) and values (%s) must match",
-                schema.getFieldNames(), values));
+                "Field count in Schema (%s) (%d) and values (%s) (%d)  must match",
+                schema.getFieldNames(), schema.getFieldCount(), values, values.size()));
       }
       for (int i = 0; i < values.size(); ++i) {
         Object value = values.get(i);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
index 35bd657..2251c6c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
@@ -23,11 +23,14 @@ import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
 import static org.junit.Assert.assertThat;
 
+import java.util.Arrays;
 import java.util.List;
+import java.util.stream.Collectors;
 import org.apache.beam.sdk.TestUtils.KvMatcher;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.Schema.TypeName;
+import org.apache.beam.sdk.schemas.transforms.CoGroup.By;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
@@ -36,7 +39,6 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.Row;
-import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
 import org.hamcrest.BaseMatcher;
@@ -183,10 +185,8 @@ public class CoGroupTest {
             .build();
 
     PCollection<KV<Row, Row>> joined =
-        PCollectionTuple.of(new TupleTag<>("pc1"), pc1)
-            .and(new TupleTag<>("pc2"), pc2)
-            .and(new TupleTag<>("pc3"), pc3)
-            .apply("CoGroup", CoGroup.byFieldNames("user", "country"));
+        PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+            .apply("CoGroup", CoGroup.join(By.fieldNames("user", "country")));
     List<KV<Row, Row>> expected =
         ImmutableList.of(
             KV.of(key1, key1Joined),
@@ -325,19 +325,13 @@ public class CoGroupTest {
                     Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build()))
             .build();
 
-    TupleTag<Row> pc1Tag = new TupleTag<>("pc1");
-    TupleTag<Row> pc2Tag = new TupleTag<>("pc2");
-    TupleTag<Row> pc3Tag = new TupleTag<>("pc3");
-
     PCollection<KV<Row, Row>> joined =
-        PCollectionTuple.of(pc1Tag, pc1)
-            .and(pc2Tag, pc2)
-            .and(pc3Tag, pc3)
+        PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
             .apply(
                 "CoGroup",
-                CoGroup.byFieldNames(pc1Tag, "user", "country")
-                    .byFieldNames(pc2Tag, "user2", "country2")
-                    .byFieldNames(pc3Tag, "user3", "country3"));
+                CoGroup.join("pc1", By.fieldNames("user", "country"))
+                    .join("pc2", By.fieldNames("user2", "country2"))
+                    .join("pc3", By.fieldNames("user3", "country3")));
 
     List<KV<Row, Row>> expected =
         ImmutableList.of(
@@ -367,19 +361,14 @@ public class CoGroupTest {
     PCollection<Row> pc3 =
         pipeline.apply(
             "Create3", Create.of(Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build()));
-    TupleTag<Row> pc1Tag = new TupleTag<>("pc1");
-    TupleTag<Row> pc2Tag = new TupleTag<>("pc2");
-    TupleTag<Row> pc3Tag = new TupleTag<>("pc3");
 
-    thrown.expect(IllegalStateException.class);
+    thrown.expect(IllegalArgumentException.class);
     PCollection<KV<Row, Row>> joined =
-        PCollectionTuple.of(pc1Tag, pc1)
-            .and(pc2Tag, pc2)
-            .and(pc3Tag, pc3)
+        PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
             .apply(
                 "CoGroup",
-                CoGroup.byFieldNames(pc1Tag, "user", "country")
-                    .byFieldNames(pc2Tag, "user2", "country2"));
+                CoGroup.join("pc1", By.fieldNames("user", "country"))
+                    .join("pc2", By.fieldNames("user2", "country2")));
     pipeline.run();
   }
 
@@ -399,13 +388,318 @@ public class CoGroupTest {
                 Create.of(Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build()))
             .setRowSchema(CG_SCHEMA_1);
 
-    TupleTag<Row> pc1Tag = new TupleTag<>("pc1");
-    TupleTag<Row> pc2Tag = new TupleTag<>("pc2");
     thrown.expect(IllegalStateException.class);
     PCollection<KV<Row, Row>> joined =
-        PCollectionTuple.of(pc1Tag, pc1)
-            .and(pc2Tag, pc2)
-            .apply("CoGroup", CoGroup.byFieldNames(pc1Tag, "user").byFieldNames(pc2Tag, "count"));
+        PCollectionTuple.of("pc1", pc1, "pc2", pc2)
+            .apply(
+                "CoGroup",
+                CoGroup.join("pc1", By.fieldNames("user")).join("pc2", By.fieldNames("count")));
+    pipeline.run();
+  }
+
+  private List<Row> innerJoin(
+      List<Row> inputs1,
+      List<Row> inputs2,
+      List<Row> inputs3,
+      String[] keys1,
+      String[] keys2,
+      String[] keys3,
+      Schema expectedSchema) {
+    List<Row> joined = Lists.newArrayList(); // Mutable: callers append outer-join rows.
+    for (Row row1 : inputs1) {
+      List<?> key1 = Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList());
+      for (Row row2 : inputs2) {
+        List<?> key2 = Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList());
+        for (Row row3 : inputs3) {
+          List<?> key3 = Arrays.stream(keys3).map(row3::getValue).collect(Collectors.toList());
+          if (key1.equals(key2) && key2.equals(key3)) {
+            joined.add(Row.withSchema(expectedSchema).addValues(row1, row2, row3).build());
+          }
+        }
+      }
+    }
+    return joined;
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testInnerJoin() {
+    List<Row> rows1 =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build());
+    List<Row> rows2 =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build());
+    List<Row> rows3 =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 18, "us").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 19, "il").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 20, "il").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 21, "fr").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 22, "fr").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build());
+    // Each (user, country) key above occurs twice per input, giving 2 x 2 x 2 joined rows per key.
+    PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(rows1)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(rows2)).setRowSchema(CG_SCHEMA_2);
+    PCollection<Row> pc3 = pipeline.apply("Create3", Create.of(rows3)).setRowSchema(CG_SCHEMA_3);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addRowField("pc1", CG_SCHEMA_1)
+            .addRowField("pc2", CG_SCHEMA_2)
+            .addRowField("pc3", CG_SCHEMA_3)
+            .build();
+
+    PCollection<Row> result =
+        PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+            .apply(
+                "CoGroup",
+                CoGroup.join("pc1", By.fieldNames("user", "country"))
+                    .join("pc2", By.fieldNames("user2", "country2"))
+                    .join("pc3", By.fieldNames("user3", "country3"))
+                    .crossProductJoin());
+    assertEquals(expectedSchema, result.getSchema());
+    // Expected output: a brute-force nested-loop join over the same input rows.
+    List<Row> expectedRows =
+        innerJoin(
+            rows1,
+            rows2,
+            rows3,
+            new String[] {"user", "country"},
+            new String[] {"user2", "country2"},
+            new String[] {"user3", "country3"},
+            expectedSchema);
+
+    PAssert.that(result).containsInAnyOrder(expectedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testFullOuterJoin() {
+    List<Row> rows1 =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user3", 7, "ar").build());
+
+    List<Row> rows2 =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "es").build());
+
+    List<Row> rows3 =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 18, "us").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 19, "il").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 20, "il").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 21, "fr").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 22, "fr").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user27", 24, "se").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(rows1)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(rows2)).setRowSchema(CG_SCHEMA_2);
+    PCollection<Row> pc3 = pipeline.apply("Create3", Create.of(rows3)).setRowSchema(CG_SCHEMA_3);
+
+    // Every input is outer-joined, so each of the three row fields is nullable.
+    Schema expectedSchema =
+        Schema.builder()
+            .addNullableField("pc1", FieldType.row(CG_SCHEMA_1))
+            .addNullableField("pc2", FieldType.row(CG_SCHEMA_2))
+            .addNullableField("pc3", FieldType.row(CG_SCHEMA_3))
+            .build();
+
+    PCollection<Row> result =
+        PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+            .apply(
+                "CoGroup",
+                CoGroup.join("pc1", By.fieldNames("user", "country").withOuterJoinParticipation())
+                    .join("pc2", By.fieldNames("user2", "country2").withOuterJoinParticipation())
+                    .join("pc3", By.fieldNames("user3", "country3").withOuterJoinParticipation())
+                    .crossProductJoin());
+    assertEquals(expectedSchema, result.getSchema());
+    // Start from the keys present in all three inputs (the brute-force inner join).
+    List<Row> expectedRows =
+        innerJoin(
+            rows1,
+            rows2,
+            rows3,
+            new String[] {"user", "country"},
+            new String[] {"user2", "country2"},
+            new String[] {"user3", "country3"},
+            expectedSchema);
+    // Keys present in only one input yield a single row with nulls for the other two inputs.
+    expectedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 7, "ar").build(), null, null)
+            .build());
+    expectedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(null, Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "es").build(), null)
+            .build());
+    expectedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(
+                null, null, Row.withSchema(CG_SCHEMA_3).addValues("user27", 24, "se").build())
+            .build());
+
+    PAssert.that(result).containsInAnyOrder(expectedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testPartialOuterJoin() {
+    List<Row> pc1Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build());
+
+    List<Row> pc2Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user3", 7, "ar").build());
+
+    List<Row> pc3Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 17, "us").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 18, "us").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 19, "il").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user1", 20, "il").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 21, "fr").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 22, "fr").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build(),
+            Row.withSchema(CG_SCHEMA_3).addValues("user3", 25, "ar").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2);
+    PCollection<Row> pc3 = pipeline.apply("Create3", Create.of(pc3Rows)).setRowSchema(CG_SCHEMA_3);
+
+    // Partial outer join. Missing entries in the "pc2" PCollection will be filled in with nulls,
+    // but not others.
+    Schema expectedSchema =
+        Schema.builder()
+            .addField("pc1", FieldType.row(CG_SCHEMA_1))
+            .addNullableField("pc2", FieldType.row(CG_SCHEMA_2))
+            .addField("pc3", FieldType.row(CG_SCHEMA_3))
+            .build();
+
+    PCollection<Row> joined =
+        PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
+            .apply(
+                "CoGroup",
+                CoGroup.join("pc1", By.fieldNames("user", "country"))
+                    .join("pc2", By.fieldNames("user2", "country2").withOuterJoinParticipation())
+                    .join("pc3", By.fieldNames("user3", "country3"))
+                    .crossProductJoin());
+    assertEquals(expectedSchema, joined.getSchema());
+
+    List<Row> expectedJoinedRows =
+        innerJoin(
+            pc1Rows,
+            pc2Rows,
+            pc3Rows,
+            new String[] {"user", "country"},
+            new String[] {"user2", "country2"},
+            new String[] {"user3", "country3"},
+            expectedSchema);
+
+    // Manually add the outer-join rows to the list of expected results. Missing results from the
+    // middle (pc2) PCollection are filled in with nulls. Missing events from other PCollections
+    // are not. Events with key ("user2", "ar") show up in pc1 and pc3 but not in pc2, so we expect
+    // the outer join to still produce those rows, with nulls for pc2. Events with key
+    // ("user3", "ar") however show up in pc2 and pc3, but not in pc1; since pc1 is marked for
+    // full participation (no outer join), these events should not be included in the join.
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(
+                Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+                null,
+                Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build())
+            .build());
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(
+                Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+                null,
+                Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build())
+            .build());
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(
+                Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+                null,
+                Row.withSchema(CG_SCHEMA_3).addValues("user2", 23, "ar").build())
+            .build());
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(
+                Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+                null,
+                Row.withSchema(CG_SCHEMA_3).addValues("user2", 24, "ar").build())
+            .build());
+
+    PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testUnmatchedTags() {
+    PCollection<Row> input1 = pipeline.apply("Create1", Create.empty(CG_SCHEMA_1));
+    PCollection<Row> input2 = pipeline.apply("Create2", Create.empty(CG_SCHEMA_2));
+    // The join below names "pc3", which is not one of the tags in the input tuple.
+    thrown.expect(IllegalArgumentException.class);
+
+    PCollectionTuple.of("pc1", input1, "pc2", input2)
+        .apply(
+            CoGroup.join("pc1", By.fieldNames("user"))
+                .join("pc3", By.fieldNames("user3"))
+                .crossProductJoin());
     pipeline.run();
   }