This is an automated email from the ASF dual-hosted git repository.

reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a00ba2b  Merge pull request #8273: [BEAM-4461] A transform to perform 
binary joins of PCollections with schemas
a00ba2b is described below

commit a00ba2bb57a9c1476b2b3e7ec3abe95f88938440
Author: reuvenlax <re...@google.com>
AuthorDate: Wed Apr 17 08:34:09 2019 -0700

    Merge pull request #8273: [BEAM-4461] A transform to perform binary joins 
of PCollections with schemas
---
 .../beam/sdk/schemas/transforms/CoGroup.java       |  60 ++--
 .../apache/beam/sdk/schemas/transforms/Join.java   | 242 +++++++++++++
 .../beam/sdk/schemas/transforms/CoGroupTest.java   |  40 +--
 .../beam/sdk/schemas/transforms/JoinTest.java      | 390 +++++++++++++++++++++
 .../beam/sdk/schemas/transforms/JoinTestUtils.java |  66 ++++
 5 files changed, 736 insertions(+), 62 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
index 7dd961a..5f95855 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java
@@ -27,6 +27,7 @@ import java.util.TreeMap;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.Schema;
@@ -53,7 +54,7 @@ import 
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
 /**
  * A transform that performs equijoins across multiple schema {@link 
PCollection}s.
  *
- * <p>This transform has similarites to {@link CoGroupByKey}, however works on 
PCollections that
+ * <p>This transform has similarities to {@link CoGroupByKey}, however works 
on PCollections that
  * have schemas. This allows users of the transform to simply specify schema 
fields to join on. The
  * output type of the transform is a {@code KV<Row, Row>} where the value 
contains one field for
  * every input PCollection and the key represents the fields that were joined 
on. By default the
@@ -117,7 +118,7 @@ import 
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
  * <p>Traditional (SQL) joins are cross-product joins. All rows that match the 
join condition are
  * combined into individual rows and returned; in fact any SQL inner joins is 
a subset of the
  * cross-product of two tables. This transform also supports the same 
functionality using the {@link
- * Inner#crossProductJoin()} method.
+ * Impl#crossProductJoin()} method.
  *
  * <p>For example, consider the SQL join: SELECT * FROM input1 INNER JOIN 
input2 ON input1.user =
  * input2.user
@@ -149,7 +150,7 @@ import 
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
  * participate fully in the join, providing inner-join semantics. This means 
that the join will only
  * produce values for "Bob" if all inputs have values for "Bob;" if even a 
single input does not
  * have a value for "Bob," an inner-join will produce no value. However, if 
you mark that input as
- * having outer-join participation then the join will contain values for 
"Bob," as long as at least
+ * having optional participation then the join will contain values for "Bob," 
as long as at least
  * one input has a "Bob" value; null values will be added for inputs that have 
no "Bob" values. To
  * continue the SQL example:
  *
@@ -159,7 +160,7 @@ import 
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
  *
  * <pre>{@code
  * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", 
input2)
- *   .apply(CoGroup.join("input1", 
By.fieldNames("user").withOuterJoinParticipation())
+ *   .apply(CoGroup.join("input1", 
By.fieldNames("user").withOptionalParticipation())
  *                 .join("input2", By.fieldNames("user"))
  *                 .crossProductJoin();
  * }</pre>
@@ -171,7 +172,7 @@ import 
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
  * <pre>{@code
  * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", 
input2)
  *   .apply(CoGroup.join("input1", By.fieldNames("user"))
- *                 .join("input2", 
By.fieldNames("user").withOuterJoinParticipation())
+ *                 .join("input2", 
By.fieldNames("user").withOptionalParticipation())
  *                 .crossProductJoin();
  * }</pre>
  *
@@ -181,17 +182,18 @@ import 
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
  *
  * <pre>{@code
  * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", 
input2)
- *   .apply(CoGroup.join("input1", 
By.fieldNames("user").withOuterJoinParticipation())
- *                 .join("input2", 
By.fieldNames("user").withOuterJoinParticipation())
+ *   .apply(CoGroup.join("input1", 
By.fieldNames("user").withOptionalParticipation())
+ *                 .join("input2", 
By.fieldNames("user").withOptionalParticipation())
  *                 .crossProductJoin();
  * }</pre>
  *
  * <p>While the above examples use two inputs to mimic SQL's left and right 
join semantics, the
- * {@link CoGroup} transform supports any number of inputs, and outer-join 
participation can be
+ * {@link CoGroup} transform supports any number of inputs, and optional 
participation can be
  * specified on any subset of them.
  *
  * <p>Do note that cross-product joins while simpler and easier to program, 
can cause
  */
+@Experimental(Experimental.Kind.SCHEMAS)
 public class CoGroup {
   private static final List NULL_LIST;
 
@@ -207,7 +209,7 @@ public class CoGroup {
   public abstract static class By implements Serializable {
     abstract FieldAccessDescriptor getFieldAccessDescriptor();
 
-    abstract boolean getOuterJoinParticipation();
+    abstract boolean getOptionalParticipation();
 
     abstract Builder toBuilder();
 
@@ -215,7 +217,7 @@ public class CoGroup {
     abstract static class Builder {
       abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor 
fieldAccessDescriptor);
 
-      abstract Builder setOuterJoinParticipation(boolean 
outerJoinParticipation);
+      abstract Builder setOptionalParticipation(boolean optionalParticipation);
 
       abstract By build();
     }
@@ -234,7 +236,7 @@ public class CoGroup {
     public static By fieldAccessDescriptor(FieldAccessDescriptor 
fieldAccessDescriptor) {
       return new AutoValue_CoGroup_By.Builder()
           .setFieldAccessDescriptor(fieldAccessDescriptor)
-          .setOuterJoinParticipation(false)
+          .setOptionalParticipation(false)
           .build();
     }
 
@@ -244,8 +246,8 @@ public class CoGroup {
      *
      * <p>This only affects the results of expandCrossProduct.
      */
-    public By withOuterJoinParticipation() {
-      return toBuilder().setOuterJoinParticipation(true).build();
+    public By withOptionalParticipation() {
+      return toBuilder().setOptionalParticipation(true).build();
     }
   }
 
@@ -275,10 +277,10 @@ public class CoGroup {
           : joinArgsMap.get(tag).getFieldAccessDescriptor();
     }
 
-    private boolean getOuterJoinParticipation(String tag) {
+    private boolean getOptionalParticipation(String tag) {
       return (allInputsJoinArgs != null)
-          ? allInputsJoinArgs.getOuterJoinParticipation()
-          : joinArgsMap.get(tag).getOuterJoinParticipation();
+          ? allInputsJoinArgs.getOptionalParticipation()
+          : joinArgsMap.get(tag).getOptionalParticipation();
     }
   }
 
@@ -287,8 +289,8 @@ public class CoGroup {
    *
    * <p>The same fields and other options are used in all input PCollections.
    */
-  public static Inner join(By clause) {
-    return new Inner(new JoinArguments(clause));
+  public static Impl join(By clause) {
+    return new Impl(new JoinArguments(clause));
   }
 
   /**
@@ -297,8 +299,8 @@ public class CoGroup {
    *
    * <p>Each PCollection in the input must have args specified for the join 
key.
    */
-  public static Inner join(String tag, By clause) {
-    return new Inner(new JoinArguments(ImmutableMap.of(tag, clause)));
+  public static Impl join(String tag, By clause) {
+    return new Impl(new JoinArguments(ImmutableMap.of(tag, clause)));
   }
 
   // Contains summary information needed for implementing the join.
@@ -421,14 +423,14 @@ public class CoGroup {
   }
 
   /** The implementing PTransform. */
-  public static class Inner extends PTransform<PCollectionTuple, 
PCollection<KV<Row, Row>>> {
+  public static class Impl extends PTransform<PCollectionTuple, 
PCollection<KV<Row, Row>>> {
     private final JoinArguments joinArgs;
 
-    private Inner() {
+    private Impl() {
       this(new JoinArguments(Collections.emptyMap()));
     }
 
-    private Inner(JoinArguments joinArgs) {
+    private Impl(JoinArguments joinArgs) {
       this.joinArgs = joinArgs;
     }
 
@@ -437,11 +439,11 @@ public class CoGroup {
      *
      * <p>Each PCollection in the input must have fields specified for the 
join key.
      */
-    public Inner join(String tag, By clause) {
+    public Impl join(String tag, By clause) {
       if (joinArgs.allInputsJoinArgs != null) {
         throw new IllegalStateException("Cannot set both a global and per-tag 
fields.");
       }
-      return new Inner(joinArgs.with(tag, clause));
+      return new Impl(joinArgs.with(tag, clause));
     }
 
     /** Expand the join into individual rows, similar to SQL joins. */
@@ -546,12 +548,12 @@ public class CoGroup {
 
     private Schema getOutputSchema(JoinInformation joinInformation) {
       // Construct the output schema. It contains one field for each input 
PCollection, of type
-      // ROW. If a field supports outer-join semantics, then that field will 
be nullable in the
+      // ROW. If a field has optional participation, then that field will be 
nullable in the
       // schema.
       Schema.Builder joinedSchemaBuilder = Schema.builder();
       for (Map.Entry<String, Schema> entry : 
joinInformation.componentSchemas.entrySet()) {
         FieldType fieldType = FieldType.row(entry.getValue());
-        if (joinArgs.getOuterJoinParticipation(entry.getKey())) {
+        if (joinArgs.getOptionalParticipation(entry.getKey())) {
           fieldType = fieldType.withNullable(true);
         }
         joinedSchemaBuilder.addField(entry.getKey(), fieldType);
@@ -613,8 +615,8 @@ public class CoGroup {
         for (int i = 0; i < sortedTags.size(); ++i) {
           String tag = sortedTags.get(i);
           Iterable items = gbkResult.getAll(tagToKeyedTag.get(i));
-          if (!items.iterator().hasNext() && 
joinArgs.getOuterJoinParticipation(tag)) {
-            // If this tag has outer-join participation, then empty should 
participate as a
+          if (!items.iterator().hasNext() && 
joinArgs.getOptionalParticipation(tag)) {
+            // If this tag has optional participation, then empty should 
participate as a
             // single null.
             items = () -> NULL_LIST.iterator();
           }
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Join.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Join.java
new file mode 100644
index 0000000..f1b7ee6
--- /dev/null
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Join.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.transforms;
+
+import java.io.Serializable;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.Row;
+
+/**
+ * A transform that performs equijoins across two schema {@link PCollection}s.
+ *
+ * <p>This transform allows joins between two input PCollections simply by specifying the fields to
+ * join on. The resulting {@code PCollection<Row>} will have two fields named "lhs" and "rhs"
+ * respectively, each with the schema of the corresponding input PCollection.
+ *
+ * <p>For example, the following demonstrates joining two PCollections using a natural join on the
+ * "user" and "country" fields, where both the left-hand and the right-hand PCollections have fields
+ * with these names.
+ *
+ * <pre>{@code
+ * PCollection<Row> joined =
+ *     pCollection1.apply(Join.<Row, Row>innerJoin(pCollection2).using("user", "country"));
+ * }</pre>
+ *
+ * <p>The left-hand element type {@code LhsT} does not appear in the arguments of the factory
+ * methods, so Java cannot infer it in a chained call; supply it explicitly as shown above (here for
+ * two {@code PCollection<Row>} inputs).
+ *
+ * <p>If the right-hand PCollection contains fields with different names to join against, you can
+ * specify them as follows:
+ *
+ * <pre>{@code
+ * PCollection<Row> joined =
+ *     pCollection1.apply(
+ *         Join.<Row, Row>innerJoin(pCollection2)
+ *             .on(FieldsEqual.left("user", "country").right("otherUser", "otherCountry")));
+ * }</pre>
+ *
+ * <p>Full outer joins, left outer joins, and right outer joins are also supported via {@link
+ * #fullOuterJoin}, {@link #leftOuterJoin} and {@link #rightOuterJoin}. In those joins, each side
+ * that may be missing is given optional participation in the underlying {@link CoGroup}, which
+ * makes its output field nullable.
+ */
+@Experimental(Experimental.Kind.SCHEMAS)
+public class Join {
+  /** Tag (and output field name) used for the left-hand input. */
+  public static final String LHS_TAG = "lhs";
+  /** Tag (and output field name) used for the right-hand input. */
+  public static final String RHS_TAG = "rhs";
+
+  /**
+   * Predicate object to specify fields to compare when doing an equi-join.
+   *
+   * <p>A predicate can be started from either side and completed with the other, e.g. {@code
+   * FieldsEqual.left("user").right("otherUser")}. A side that is never specified defaults to
+   * {@link FieldAccessDescriptor#create()}.
+   */
+  public static class FieldsEqual {
+    /** Starts a predicate using the named fields of the left-hand input. */
+    public static Impl left(String... fieldNames) {
+      return new Impl(
+          FieldAccessDescriptor.withFieldNames(fieldNames), FieldAccessDescriptor.create());
+    }
+
+    /** Starts a predicate using the left-hand input fields with the given ids. */
+    public static Impl left(Integer... fieldIds) {
+      return new Impl(FieldAccessDescriptor.withFieldIds(fieldIds), FieldAccessDescriptor.create());
+    }
+
+    /** Starts a predicate using the given left-hand field access descriptor. */
+    public static Impl left(FieldAccessDescriptor fieldAccessDescriptor) {
+      return new Impl(fieldAccessDescriptor, FieldAccessDescriptor.create());
+    }
+
+    // The right(...) factories are static, like left(...), so that FieldsEqual.right(...) can be
+    // called directly; FieldsEqual itself carries no state.
+
+    /** Starts a predicate using the named fields of the right-hand input. */
+    public static Impl right(String... fieldNames) {
+      return new Impl(
+          FieldAccessDescriptor.create(), FieldAccessDescriptor.withFieldNames(fieldNames));
+    }
+
+    /** Starts a predicate using the right-hand input fields with the given ids. */
+    public static Impl right(Integer... fieldIds) {
+      return new Impl(FieldAccessDescriptor.create(), FieldAccessDescriptor.withFieldIds(fieldIds));
+    }
+
+    /** Starts a predicate using the given right-hand field access descriptor. */
+    public static Impl right(FieldAccessDescriptor fieldAccessDescriptor) {
+      return new Impl(FieldAccessDescriptor.create(), fieldAccessDescriptor);
+    }
+
+    /**
+     * A pair of field selections, one per join side. Instances are immutable: every {@code left}
+     * or {@code right} call returns a new instance with that side replaced.
+     */
+    public static class Impl implements Serializable {
+      private final FieldAccessDescriptor lhs;
+      private final FieldAccessDescriptor rhs;
+
+      private Impl(FieldAccessDescriptor lhs, FieldAccessDescriptor rhs) {
+        this.lhs = lhs;
+        this.rhs = rhs;
+      }
+
+      /** Replaces the left-hand side with the named fields. */
+      public Impl left(String... fieldNames) {
+        return new Impl(FieldAccessDescriptor.withFieldNames(fieldNames), rhs);
+      }
+
+      /** Replaces the left-hand side with the fields with the given ids. */
+      public Impl left(Integer... fieldIds) {
+        return new Impl(FieldAccessDescriptor.withFieldIds(fieldIds), rhs);
+      }
+
+      /** Replaces the left-hand side with the given field access descriptor. */
+      public Impl left(FieldAccessDescriptor fieldAccessDescriptor) {
+        return new Impl(fieldAccessDescriptor, rhs);
+      }
+
+      /** Replaces the right-hand side with the named fields. */
+      public Impl right(String... fieldNames) {
+        return new Impl(lhs, FieldAccessDescriptor.withFieldNames(fieldNames));
+      }
+
+      /** Replaces the right-hand side with the fields with the given ids. */
+      public Impl right(Integer... fieldIds) {
+        return new Impl(lhs, FieldAccessDescriptor.withFieldIds(fieldIds));
+      }
+
+      /** Replaces the right-hand side with the given field access descriptor. */
+      public Impl right(FieldAccessDescriptor fieldAccessDescriptor) {
+        return new Impl(lhs, fieldAccessDescriptor);
+      }
+
+      /** Resolves each side's descriptor against the schema of the corresponding input. */
+      private Impl resolve(Schema lhsSchema, Schema rhsSchema) {
+        return new Impl(lhs.resolve(lhsSchema), rhs.resolve(rhsSchema));
+      }
+    }
+  }
+
+  /** Perform an inner join: only keys present in both inputs produce output. */
+  public static <LhsT, RhsT> Impl<LhsT, RhsT> innerJoin(PCollection<RhsT> rhs) {
+    return new Impl<>(JoinType.INNER, rhs);
+  }
+
+  /** Perform a full outer join: both sides have optional participation. */
+  public static <LhsT, RhsT> Impl<LhsT, RhsT> fullOuterJoin(PCollection<RhsT> rhs) {
+    return new Impl<>(JoinType.OUTER, rhs);
+  }
+
+  /** Perform a left outer join: the right-hand side has optional participation. */
+  public static <LhsT, RhsT> Impl<LhsT, RhsT> leftOuterJoin(PCollection<RhsT> rhs) {
+    return new Impl<>(JoinType.LEFT_OUTER, rhs);
+  }
+
+  /** Perform a right outer join: the left-hand side has optional participation. */
+  public static <LhsT, RhsT> Impl<LhsT, RhsT> rightOuterJoin(PCollection<RhsT> rhs) {
+    return new Impl<>(JoinType.RIGHT_OUTER, rhs);
+  }
+
+  /** Supported join types, recording which inputs get optional participation in the CoGroup. */
+  private enum JoinType {
+    INNER(false, false),
+    OUTER(true, true),
+    LEFT_OUTER(false, true),
+    RIGHT_OUTER(true, false);
+
+    // Whether the left-hand input is marked withOptionalParticipation() in the CoGroup.
+    private final boolean lhsOptional;
+    // Whether the right-hand input is marked withOptionalParticipation() in the CoGroup.
+    private final boolean rhsOptional;
+
+    JoinType(boolean lhsOptional, boolean rhsOptional) {
+      this.lhsOptional = lhsOptional;
+      this.rhsOptional = rhsOptional;
+    }
+  }
+
+  /**
+   * The join {@link PTransform}, applied to the left-hand input. Instances are immutable; {@code
+   * using} and {@code on} return new instances. A predicate must be set with one of them before
+   * the transform is applied.
+   */
+  public static class Impl<LhsT, RhsT> extends PTransform<PCollection<LhsT>, PCollection<Row>> {
+    private final JoinType joinType;
+    private final transient PCollection<RhsT> rhs;
+    @Nullable private final FieldsEqual.Impl predicate;
+
+    private Impl(JoinType joinType, PCollection<RhsT> rhs) {
+      this(joinType, rhs, null);
+    }
+
+    private Impl(JoinType joinType, PCollection<RhsT> rhs, @Nullable FieldsEqual.Impl predicate) {
+      this.joinType = joinType;
+      this.rhs = rhs;
+      this.predicate = predicate;
+    }
+
+    /**
+     * Perform a natural join between the PCollections. The named fields are expected to exist in
+     * both PCollections.
+     */
+    public Impl<LhsT, RhsT> using(String... fieldNames) {
+      return new Impl<>(joinType, rhs, FieldsEqual.left(fieldNames).right(fieldNames));
+    }
+
+    /**
+     * Perform a natural join between the PCollections. Fields with the given ids are expected to
+     * exist in both PCollections.
+     */
+    public Impl<LhsT, RhsT> using(Integer... fieldIds) {
+      return new Impl<>(joinType, rhs, FieldsEqual.left(fieldIds).right(fieldIds));
+    }
+
+    /**
+     * Perform a natural join between the PCollections. The described fields are expected to exist
+     * in both PCollections.
+     */
+    public Impl<LhsT, RhsT> using(FieldAccessDescriptor fieldAccessDescriptor) {
+      return new Impl<>(
+          joinType, rhs, FieldsEqual.left(fieldAccessDescriptor).right(fieldAccessDescriptor));
+    }
+
+    /** Join the PCollections using the provided predicate. */
+    public Impl<LhsT, RhsT> on(FieldsEqual.Impl predicate) {
+      return new Impl<>(joinType, rhs, predicate);
+    }
+
+    /**
+     * Expands into a {@link CoGroup} cross-product join over the two inputs, tagged {@link
+     * #LHS_TAG} and {@link #RHS_TAG}.
+     *
+     * @throws IllegalStateException if neither {@code using} nor {@code on} was called
+     */
+    @Override
+    public PCollection<Row> expand(PCollection<LhsT> lhs) {
+      if (predicate == null) {
+        throw new IllegalStateException(
+            "No join fields specified; call using(...) or on(...) before applying the join.");
+      }
+      FieldsEqual.Impl resolvedPredicate = predicate.resolve(lhs.getSchema(), rhs.getSchema());
+
+      // Outer joins are expressed as optional participation of the side(s) that may be missing.
+      CoGroup.By lhsBy = CoGroup.By.fieldAccessDescriptor(resolvedPredicate.lhs);
+      if (joinType.lhsOptional) {
+        lhsBy = lhsBy.withOptionalParticipation();
+      }
+      CoGroup.By rhsBy = CoGroup.By.fieldAccessDescriptor(resolvedPredicate.rhs);
+      if (joinType.rhsOptional) {
+        rhsBy = rhsBy.withOptionalParticipation();
+      }
+
+      return PCollectionTuple.of(LHS_TAG, lhs)
+          .and(RHS_TAG, rhs)
+          .apply(CoGroup.join(LHS_TAG, lhsBy).join(RHS_TAG, rhsBy).crossProductJoin());
+    }
+  }
+}
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
index 2251c6c..184f215 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
@@ -23,9 +23,7 @@ import static org.hamcrest.CoreMatchers.equalTo;
 import static 
org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
 import static org.junit.Assert.assertThat;
 
-import java.util.Arrays;
 import java.util.List;
-import java.util.stream.Collectors;
 import org.apache.beam.sdk.TestUtils.KvMatcher;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -397,30 +395,6 @@ public class CoGroupTest {
     pipeline.run();
   }
 
-  private List<Row> innerJoin(
-      List<Row> inputs1,
-      List<Row> inputs2,
-      List<Row> inputs3,
-      String[] keys1,
-      String[] keys2,
-      String[] keys3,
-      Schema expectedSchema) {
-    List<Row> joined = Lists.newArrayList();
-    for (Row row1 : inputs1) {
-      for (Row row2 : inputs2) {
-        for (Row row3 : inputs3) {
-          List key1 = 
Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList());
-          List key2 = 
Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList());
-          List key3 = 
Arrays.stream(keys3).map(row3::getValue).collect(Collectors.toList());
-          if (key1.equals(key2) && key2.equals(key3)) {
-            joined.add(Row.withSchema(expectedSchema).addValues(row1, row2, 
row3).build());
-          }
-        }
-      }
-    }
-    return joined;
-  }
-
   @Test
   @Category(NeedsRunner.class)
   public void testInnerJoin() {
@@ -477,7 +451,7 @@ public class CoGroupTest {
     assertEquals(expectedSchema, joined.getSchema());
 
     List<Row> expectedJoinedRows =
-        innerJoin(
+        JoinTestUtils.innerJoin(
             pc1Rows,
             pc2Rows,
             pc3Rows,
@@ -545,14 +519,14 @@ public class CoGroupTest {
         PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3)
             .apply(
                 "CoGroup",
-                CoGroup.join("pc1", By.fieldNames("user", 
"country").withOuterJoinParticipation())
-                    .join("pc2", By.fieldNames("user2", 
"country2").withOuterJoinParticipation())
-                    .join("pc3", By.fieldNames("user3", 
"country3").withOuterJoinParticipation())
+                CoGroup.join("pc1", By.fieldNames("user", 
"country").withOptionalParticipation())
+                    .join("pc2", By.fieldNames("user2", 
"country2").withOptionalParticipation())
+                    .join("pc3", By.fieldNames("user3", 
"country3").withOptionalParticipation())
                     .crossProductJoin());
     assertEquals(expectedSchema, joined.getSchema());
 
     List<Row> expectedJoinedRows =
-        innerJoin(
+        JoinTestUtils.innerJoin(
             pc1Rows,
             pc2Rows,
             pc3Rows,
@@ -633,13 +607,13 @@ public class CoGroupTest {
             .apply(
                 "CoGroup",
                 CoGroup.join("pc1", By.fieldNames("user", "country"))
-                    .join("pc2", By.fieldNames("user2", 
"country2").withOuterJoinParticipation())
+                    .join("pc2", By.fieldNames("user2", 
"country2").withOptionalParticipation())
                     .join("pc3", By.fieldNames("user3", "country3"))
                     .crossProductJoin());
     assertEquals(expectedSchema, joined.getSchema());
 
     List<Row> expectedJoinedRows =
-        innerJoin(
+        JoinTestUtils.innerJoin(
             pc1Rows,
             pc2Rows,
             pc3Rows,
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTest.java
new file mode 100644
index 0000000..0b2631f
--- /dev/null
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTest.java
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.transforms;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.apache.beam.sdk.schemas.transforms.JoinTestUtils.innerJoin;
+
+import java.util.List;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.transforms.Join.FieldsEqual;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+/** Tests for {@link org.apache.beam.sdk.schemas.transforms.Join}. */
+public class JoinTest {
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+  private static final Schema CG_SCHEMA_1 =
+      Schema.builder()
+          .addStringField("user")
+          .addInt32Field("count")
+          .addStringField("country")
+          .build();
+  private static final Schema CG_SCHEMA_2 =
+      Schema.builder()
+          .addStringField("user2")
+          .addInt32Field("count2")
+          .addStringField("country2")
+          .build();
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testInnerJoinSameKeys() {
+    List<Row> pc1Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build());
+    List<Row> pc2Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", 
Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", 
Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addRowField(Join.LHS_TAG, CG_SCHEMA_1)
+            .addRowField(Join.RHS_TAG, CG_SCHEMA_1)
+            .build();
+
+    PCollection<Row> joined = pc1.apply(Join.<Row, 
Row>innerJoin(pc2).using("user", "country"));
+
+    assertEquals(expectedSchema, joined.getSchema());
+
+    List<Row> expectedJoinedRows =
+        innerJoin(
+            pc1Rows,
+            pc2Rows,
+            new String[] {"user", "country"},
+            new String[] {"user", "country"},
+            expectedSchema);
+
+    PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testInnerJoinDifferentKeys() {
+    List<Row> pc1Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build());
+    List<Row> pc2Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user4", 8, "ar").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", 
Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", 
Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addRowField(Join.LHS_TAG, CG_SCHEMA_1)
+            .addRowField(Join.RHS_TAG, CG_SCHEMA_2)
+            .build();
+
+    PCollection<Row> joined =
+        pc1.apply(
+            Join.<Row, Row>innerJoin(pc2)
+                .on(FieldsEqual.left("user", "country").right("user2", 
"country2")));
+
+    assertEquals(expectedSchema, joined.getSchema());
+
+    List<Row> expectedJoinedRows =
+        innerJoin(
+            pc1Rows,
+            pc2Rows,
+            new String[] {"user", "country"},
+            new String[] {"user2", "country2"},
+            expectedSchema);
+
+    PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testOuterJoinDifferentKeys() {
+    List<Row> pc1Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build());
+    List<Row> pc2Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build(),
+            Row.withSchema(CG_SCHEMA_2).addValues("user4", 8, "ar").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", 
Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", 
Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addNullableField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1))
+            .addNullableField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_2))
+            .build();
+
+    PCollection<Row> joined =
+        pc1.apply(
+            Join.<Row, Row>fullOuterJoin(pc2)
+                .on(FieldsEqual.left("user", "country").right("user2", 
"country2")));
+
+    assertEquals(expectedSchema, joined.getSchema());
+
+    List<Row> expectedJoinedRows =
+        innerJoin(
+            pc1Rows,
+            pc2Rows,
+            new String[] {"user", "country"},
+            new String[] {"user2", "country2"},
+            expectedSchema);
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, 
"ar").build(), null)
+            .build());
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(null, Row.withSchema(CG_SCHEMA_2).addValues("user4", 8, 
"ar").build())
+            .build());
+
+    PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testOuterJoinSameKeys() {
+    List<Row> pc1Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build());
+    List<Row> pc2Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", 
Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", 
Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addNullableField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1))
+            .addNullableField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_1))
+            .build();
+
+    PCollection<Row> joined = pc1.apply(Join.<Row, 
Row>fullOuterJoin(pc2).using("user", "country"));
+
+    assertEquals(expectedSchema, joined.getSchema());
+
+    List<Row> expectedJoinedRows =
+        innerJoin(
+            pc1Rows,
+            pc2Rows,
+            new String[] {"user", "country"},
+            new String[] {"user", "country"},
+            expectedSchema);
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, 
"ar").build(), null)
+            .build());
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(null, Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, 
"ar").build())
+            .build());
+
+    PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testLeftOuterJoinSameKeys() {
+    List<Row> pc1Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build());
+    List<Row> pc2Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", 
Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", 
Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1))
+            .addNullableField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_1))
+            .build();
+
+    PCollection<Row> joined = pc1.apply(Join.<Row, 
Row>leftOuterJoin(pc2).using("user", "country"));
+
+    assertEquals(expectedSchema, joined.getSchema());
+
+    List<Row> expectedJoinedRows =
+        innerJoin(
+            pc1Rows,
+            pc2Rows,
+            new String[] {"user", "country"},
+            new String[] {"user", "country"},
+            expectedSchema);
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, 
"ar").build(), null)
+            .build());
+
+    PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testRightOuterJoinSameKeys() {
+    List<Row> pc1Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build());
+    List<Row> pc2Rows =
+        Lists.newArrayList(
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(),
+            Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build());
+
+    PCollection<Row> pc1 = pipeline.apply("Create1", 
Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1);
+    PCollection<Row> pc2 = pipeline.apply("Create2", 
Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addNullableField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1))
+            .addField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_1))
+            .build();
+
+    PCollection<Row> joined =
+        pc1.apply(Join.<Row, Row>rightOuterJoin(pc2).using("user", "country"));
+
+    assertEquals(expectedSchema, joined.getSchema());
+
+    List<Row> expectedJoinedRows =
+        innerJoin(
+            pc1Rows,
+            pc2Rows,
+            new String[] {"user", "country"},
+            new String[] {"user", "country"},
+            expectedSchema);
+    expectedJoinedRows.add(
+        Row.withSchema(expectedSchema)
+            .addValues(null, Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, 
"ar").build())
+            .build());
+
+    PAssert.that(joined).containsInAnyOrder(expectedJoinedRows);
+    pipeline.run();
+  }
+}
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTestUtils.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTestUtils.java
new file mode 100644
index 0000000..2ef6370
--- /dev/null
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTestUtils.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.transforms;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
+
+class JoinTestUtils {
+  static List<Row> innerJoin(
+      List<Row> inputs1, List<Row> inputs2, String[] keys1, String[] keys2, 
Schema expectedSchema) {
+    List<Row> joined = Lists.newArrayList();
+    for (Row row1 : inputs1) {
+      for (Row row2 : inputs2) {
+        List key1 = 
Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList());
+        List key2 = 
Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList());
+        if (key1.equals(key2)) {
+          joined.add(Row.withSchema(expectedSchema).addValues(row1, 
row2).build());
+        }
+      }
+    }
+    return joined;
+  }
+
+  static List<Row> innerJoin(
+      List<Row> inputs1,
+      List<Row> inputs2,
+      List<Row> inputs3,
+      String[] keys1,
+      String[] keys2,
+      String[] keys3,
+      Schema expectedSchema) {
+    List<Row> joined = Lists.newArrayList();
+    for (Row row1 : inputs1) {
+      for (Row row2 : inputs2) {
+        for (Row row3 : inputs3) {
+          List key1 = 
Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList());
+          List key2 = 
Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList());
+          List key3 = 
Arrays.stream(keys3).map(row3::getValue).collect(Collectors.toList());
+          if (key1.equals(key2) && key2.equals(key3)) {
+            joined.add(Row.withSchema(expectedSchema).addValues(row1, row2, 
row3).build());
+          }
+        }
+      }
+    }
+    return joined;
+  }
+}

Reply via email to