This is an automated email from the ASF dual-hosted git repository. reuvenlax pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new a00ba2b Merge pull request #8273: [BEAM-4461] A transform to perform binary joins of PCollections with schemas a00ba2b is described below commit a00ba2bb57a9c1476b2b3e7ec3abe95f88938440 Author: reuvenlax <re...@google.com> AuthorDate: Wed Apr 17 08:34:09 2019 -0700 Merge pull request #8273: [BEAM-4461] A transform to perform binary joins of PCollections with schemas --- .../beam/sdk/schemas/transforms/CoGroup.java | 60 ++-- .../apache/beam/sdk/schemas/transforms/Join.java | 242 +++++++++++++ .../beam/sdk/schemas/transforms/CoGroupTest.java | 40 +-- .../beam/sdk/schemas/transforms/JoinTest.java | 390 +++++++++++++++++++++ .../beam/sdk/schemas/transforms/JoinTestUtils.java | 66 ++++ 5 files changed, 736 insertions(+), 62 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java index 7dd961a..5f95855 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java @@ -27,6 +27,7 @@ import java.util.TreeMap; import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.Schema; @@ -53,7 +54,7 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps; /** * A transform that performs equijoins across multiple schema {@link PCollection}s. * - * <p>This transform has similarites to {@link CoGroupByKey}, however works on PCollections that + * <p>This transform has similarities to {@link CoGroupByKey}, however works on PCollections that * have schemas. 
This allows users of the transform to simply specify schema fields to join on. The * output type of the transform is a {@code KV<Row, Row>} where the value contains one field for * every input PCollection and the key represents the fields that were joined on. By default the @@ -117,7 +118,7 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps; * <p>Traditional (SQL) joins are cross-product joins. All rows that match the join condition are * combined into individual rows and returned; in fact any SQL inner join is a subset of the * cross-product of two tables. This transform also supports the same functionality using the {@link - * Inner#crossProductJoin()} method. + * Impl#crossProductJoin()} method. * * <p>For example, consider the SQL join: SELECT * FROM input1 INNER JOIN input2 ON input1.user = * input2.user @@ -149,7 +150,7 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps; * participate fully in the join, providing inner-join semantics. This means that the join will only * produce values for "Bob" if all inputs have values for "Bob;" if even a single input does not * have a value for "Bob," an inner-join will produce no value. However, if you mark that input as - * having outer-join participation then the join will contain values for "Bob," as long as at least + * having optional participation then the join will contain values for "Bob," as long as at least * one input has a "Bob" value; null values will be added for inputs that have no "Bob" values.
To * continue the SQL example: * @@ -159,7 +160,7 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps; * * <pre>{@code * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2) - * .apply(CoGroup.join("input1", By.fieldNames("user").withOuterJoinParticipation()) + * .apply(CoGroup.join("input1", By.fieldNames("user").withOptionalParticipation()) * .join("input2", By.fieldNames("user")) * .crossProductJoin(); * }</pre> @@ -171,7 +172,7 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps; * <pre>{@code * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2) * .apply(CoGroup.join("input1", By.fieldNames("user")) - * .join("input2", By.fieldNames("user").withOuterJoinParticipation()) + * .join("input2", By.fieldNames("user").withOptionalParticipation()) * .crossProductJoin(); * }</pre> * @@ -181,17 +182,18 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps; * * <pre>{@code * PCollection<Row> joined = PCollectionTuple.of("input1", input1, "input2", input2) - * .apply(CoGroup.join("input1", By.fieldNames("user").withOuterJoinParticipation()) - * .join("input2", By.fieldNames("user").withOuterJoinParticipation()) + * .apply(CoGroup.join("input1", By.fieldNames("user").withOptionalParticipation()) + * .join("input2", By.fieldNames("user").withOptionalParticipation()) * .crossProductJoin(); * }</pre> * * <p>While the above examples use two inputs to mimic SQL's left and right join semantics, the - * {@link CoGroup} transform supports any number of inputs, and outer-join participation can be + * {@link CoGroup} transform supports any number of inputs, and optional participation can be * specified on any subset of them. 
* * <p>Do note that cross-product joins while simpler and easier to program, can cause */ +@Experimental(Experimental.Kind.SCHEMAS) public class CoGroup { private static final List NULL_LIST; @@ -207,7 +209,7 @@ public class CoGroup { public abstract static class By implements Serializable { abstract FieldAccessDescriptor getFieldAccessDescriptor(); - abstract boolean getOuterJoinParticipation(); + abstract boolean getOptionalParticipation(); abstract Builder toBuilder(); @@ -215,7 +217,7 @@ public class CoGroup { abstract static class Builder { abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor); - abstract Builder setOuterJoinParticipation(boolean outerJoinParticipation); + abstract Builder setOptionalParticipation(boolean optionalParticipation); abstract By build(); } @@ -234,7 +236,7 @@ public class CoGroup { public static By fieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) { return new AutoValue_CoGroup_By.Builder() .setFieldAccessDescriptor(fieldAccessDescriptor) - .setOuterJoinParticipation(false) + .setOptionalParticipation(false) .build(); } @@ -244,8 +246,8 @@ public class CoGroup { * * <p>This only affects the results of expandCrossProduct. */ - public By withOuterJoinParticipation() { - return toBuilder().setOuterJoinParticipation(true).build(); + public By withOptionalParticipation() { + return toBuilder().setOptionalParticipation(true).build(); } } @@ -275,10 +277,10 @@ public class CoGroup { : joinArgsMap.get(tag).getFieldAccessDescriptor(); } - private boolean getOuterJoinParticipation(String tag) { + private boolean getOptionalParticipation(String tag) { return (allInputsJoinArgs != null) - ? allInputsJoinArgs.getOuterJoinParticipation() - : joinArgsMap.get(tag).getOuterJoinParticipation(); + ? 
allInputsJoinArgs.getOptionalParticipation() + : joinArgsMap.get(tag).getOptionalParticipation(); } } @@ -287,8 +289,8 @@ public class CoGroup { * * <p>The same fields and other options are used in all input PCollections. */ - public static Inner join(By clause) { - return new Inner(new JoinArguments(clause)); + public static Impl join(By clause) { + return new Impl(new JoinArguments(clause)); } /** @@ -297,8 +299,8 @@ public class CoGroup { * * <p>Each PCollection in the input must have args specified for the join key. */ - public static Inner join(String tag, By clause) { - return new Inner(new JoinArguments(ImmutableMap.of(tag, clause))); + public static Impl join(String tag, By clause) { + return new Impl(new JoinArguments(ImmutableMap.of(tag, clause))); } // Contains summary information needed for implementing the join. @@ -421,14 +423,14 @@ public class CoGroup { } /** The implementing PTransform. */ - public static class Inner extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> { + public static class Impl extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> { private final JoinArguments joinArgs; - private Inner() { + private Impl() { this(new JoinArguments(Collections.emptyMap())); } - private Inner(JoinArguments joinArgs) { + private Impl(JoinArguments joinArgs) { this.joinArgs = joinArgs; } @@ -437,11 +439,11 @@ public class CoGroup { * * <p>Each PCollection in the input must have fields specified for the join key. */ - public Inner join(String tag, By clause) { + public Impl join(String tag, By clause) { if (joinArgs.allInputsJoinArgs != null) { throw new IllegalStateException("Cannot set both a global and per-tag fields."); } - return new Inner(joinArgs.with(tag, clause)); + return new Impl(joinArgs.with(tag, clause)); } /** Expand the join into individual rows, similar to SQL joins. */ @@ -546,12 +548,12 @@ public class CoGroup { private Schema getOutputSchema(JoinInformation joinInformation) { // Construct the output schema. 
It contains one field for each input PCollection, of type - // ROW. If a field supports outer-join semantics, then that field will be nullable in the + // ROW. If a field has optional participation, then that field will be nullable in the // schema. Schema.Builder joinedSchemaBuilder = Schema.builder(); for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) { FieldType fieldType = FieldType.row(entry.getValue()); - if (joinArgs.getOuterJoinParticipation(entry.getKey())) { + if (joinArgs.getOptionalParticipation(entry.getKey())) { fieldType = fieldType.withNullable(true); } joinedSchemaBuilder.addField(entry.getKey(), fieldType); @@ -613,8 +615,8 @@ public class CoGroup { for (int i = 0; i < sortedTags.size(); ++i) { String tag = sortedTags.get(i); Iterable items = gbkResult.getAll(tagToKeyedTag.get(i)); - if (!items.iterator().hasNext() && joinArgs.getOuterJoinParticipation(tag)) { - // If this tag has outer-join participation, then empty should participate as a + if (!items.iterator().hasNext() && joinArgs.getOptionalParticipation(tag)) { + // If this tag has optional participation, then empty should participate as a // single null. items = () -> NULL_LIST.iterator(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Join.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Join.java new file mode 100644 index 0000000..f1b7ee6 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Join.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas.transforms; + +import java.io.Serializable; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; + +/** + * A transform that performs equijoins across two schema {@link PCollection}s. + * + * <p>This transform allows joins between two input PCollections simply by specifying the fields to + * join on. The resulting {@code PCollection<Row>} will have two fields named "lhs" and "rhs" + * respectively, each with the schema of the corresponding input PCollection. + * + * <p>For example, the following demonstrates joining two PCollections using a natural join on the + * "user" and "country" fields, where both the left-hand and the right-hand PCollections have fields + * with these names. 
+ * + * <pre> + * {@code PCollection<Row> joined = pCollection1.apply(Join.innerJoin(pCollection2).using("user", "country")); + * }</pre> + * + * <p>If the right-hand PCollection contains fields with different names to join against, you can + * specify them as follows: + * + * <pre>{@code PCollection<Row> joined = pCollection1.apply(Join.innerJoin(pCollection2) + * .on(FieldsEqual.left("user", "country").right("otherUser", "otherCountry"))); + * }</pre> + * + * <p>Full outer joins, left outer joins, and right outer joins are also supported. + */ +@Experimental(Experimental.Kind.SCHEMAS) +public class Join { + public static final String LHS_TAG = "lhs"; + public static final String RHS_TAG = "rhs"; + + /** Predicate object to specify fields to compare when doing an equi-join. */ + public static class FieldsEqual { + public static Impl left(String... fieldNames) { + return new Impl( + FieldAccessDescriptor.withFieldNames(fieldNames), FieldAccessDescriptor.create()); + } + + public static Impl left(Integer... fieldIds) { + return new Impl(FieldAccessDescriptor.withFieldIds(fieldIds), FieldAccessDescriptor.create()); + } + + public static Impl left(FieldAccessDescriptor fieldAccessDescriptor) { + return new Impl(fieldAccessDescriptor, FieldAccessDescriptor.create()); + } + + public Impl right(String... fieldNames) { + return new Impl( + FieldAccessDescriptor.create(), FieldAccessDescriptor.withFieldNames(fieldNames)); + } + + public Impl right(Integer... fieldIds) { + return new Impl(FieldAccessDescriptor.create(), FieldAccessDescriptor.withFieldIds(fieldIds)); + } + + public Impl right(FieldAccessDescriptor fieldAccessDescriptor) { + return new Impl(FieldAccessDescriptor.create(), fieldAccessDescriptor); + } + + /** Implementation class for FieldsEqual. 
*/ + public static class Impl implements Serializable { + private FieldAccessDescriptor lhs; + private FieldAccessDescriptor rhs; + + private Impl(FieldAccessDescriptor lhs, FieldAccessDescriptor rhs) { + this.lhs = lhs; + this.rhs = rhs; + } + + public Impl left(String... fieldNames) { + return new Impl(FieldAccessDescriptor.withFieldNames(fieldNames), rhs); + } + + public Impl left(Integer... fieldIds) { + return new Impl(FieldAccessDescriptor.withFieldIds(fieldIds), rhs); + } + + public Impl left(FieldAccessDescriptor fieldAccessDescriptor) { + return new Impl(fieldAccessDescriptor, rhs); + } + + public Impl right(String... fieldNames) { + return new Impl(lhs, FieldAccessDescriptor.withFieldNames(fieldNames)); + } + + public Impl right(Integer... fieldIds) { + return new Impl(lhs, FieldAccessDescriptor.withFieldIds(fieldIds)); + } + + public Impl right(FieldAccessDescriptor fieldAccessDescriptor) { + return new Impl(lhs, fieldAccessDescriptor); + } + + private Impl resolve(Schema lhsSchema, Schema rhsSchema) { + return new Impl(lhs.resolve(lhsSchema), rhs.resolve(rhsSchema)); + } + } + } + + /** Perform an inner join. */ + public static <LhsT, RhsT> Impl<LhsT, RhsT> innerJoin(PCollection<RhsT> rhs) { + return new Impl<>(JoinType.INNER, rhs); + } + + /** Perform a full outer join. */ + public static <LhsT, RhsT> Impl<LhsT, RhsT> fullOuterJoin(PCollection<RhsT> rhs) { + return new Impl<>(JoinType.OUTER, rhs); + } + + /** Perform a left outer join. */ + public static <LhsT, RhsT> Impl<LhsT, RhsT> leftOuterJoin(PCollection<RhsT> rhs) { + return new Impl<>(JoinType.LEFT_OUTER, rhs); + } + + /** Perform a right outer join. */ + public static <LhsT, RhsT> Impl<LhsT, RhsT> rightOuterJoin(PCollection<RhsT> rhs) { + return new Impl<>(JoinType.RIGHT_OUTER, rhs); + }; + + private enum JoinType { + INNER, + OUTER, + LEFT_OUTER, + RIGHT_OUTER + }; + + /** Implementation class . 
*/ + public static class Impl<LhsT, RhsT> extends PTransform<PCollection<LhsT>, PCollection<Row>> { + private final JoinType joinType; + private final transient PCollection<RhsT> rhs; + @Nullable private final FieldsEqual.Impl predicate; + + private Impl(JoinType joinType, PCollection<RhsT> rhs) { + this(joinType, rhs, null); + } + + private Impl(JoinType joinType, PCollection<RhsT> rhs, FieldsEqual.Impl predicate) { + this.joinType = joinType; + this.rhs = rhs; + this.predicate = predicate; + } + + /** + * Perform a natural join between the PCollections. The fields are expected to exist in both + * PCollections + */ + public Impl<LhsT, RhsT> using(String... fieldNames) { + return new Impl<>(joinType, rhs, FieldsEqual.left(fieldNames).right(fieldNames)); + } + + /** + * Perform a natural join between the PCollections. The fields are expected to exist in both + * PCollections + */ + public Impl<LhsT, RhsT> using(Integer... fieldIds) { + return new Impl<>(joinType, rhs, FieldsEqual.left(fieldIds).right(fieldIds)); + } + + /** + * Perform a natural join between the PCollections. The fields are expected to exist in both + * PCollections + */ + public Impl<LhsT, RhsT> using(FieldAccessDescriptor fieldAccessDescriptor) { + return new Impl<>( + joinType, rhs, FieldsEqual.left(fieldAccessDescriptor).right(fieldAccessDescriptor)); + } + + /** Join the PCollections using the provided predicate. 
*/ + public Impl<LhsT, RhsT> on(FieldsEqual.Impl predicate) { + return new Impl<>(joinType, rhs, predicate); + } + + @Override + public PCollection<Row> expand(PCollection lhs) { + FieldsEqual.Impl resolvedPredicate = predicate.resolve(lhs.getSchema(), rhs.getSchema()); + PCollectionTuple tuple = PCollectionTuple.of(LHS_TAG, lhs).and(RHS_TAG, rhs); + switch (joinType) { + case INNER: + return tuple.apply( + CoGroup.join(LHS_TAG, CoGroup.By.fieldAccessDescriptor(resolvedPredicate.lhs)) + .join(RHS_TAG, CoGroup.By.fieldAccessDescriptor(resolvedPredicate.rhs)) + .crossProductJoin()); + case OUTER: + return tuple.apply( + CoGroup.join( + LHS_TAG, + CoGroup.By.fieldAccessDescriptor(resolvedPredicate.lhs) + .withOptionalParticipation()) + .join( + RHS_TAG, + CoGroup.By.fieldAccessDescriptor(resolvedPredicate.rhs) + .withOptionalParticipation()) + .crossProductJoin()); + case LEFT_OUTER: + return tuple.apply( + CoGroup.join(LHS_TAG, CoGroup.By.fieldAccessDescriptor(resolvedPredicate.lhs)) + .join( + RHS_TAG, + CoGroup.By.fieldAccessDescriptor(resolvedPredicate.rhs) + .withOptionalParticipation()) + .crossProductJoin()); + case RIGHT_OUTER: + return tuple.apply( + CoGroup.join( + LHS_TAG, + CoGroup.By.fieldAccessDescriptor(resolvedPredicate.lhs) + .withOptionalParticipation()) + .join(RHS_TAG, CoGroup.By.fieldAccessDescriptor(resolvedPredicate.rhs)) + .crossProductJoin()); + default: + throw new RuntimeException("Unexpected join type"); + } + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java index 2251c6c..184f215 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java @@ -23,9 +23,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static 
org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.junit.Assert.assertThat; -import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import org.apache.beam.sdk.TestUtils.KvMatcher; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -397,30 +395,6 @@ public class CoGroupTest { pipeline.run(); } - private List<Row> innerJoin( - List<Row> inputs1, - List<Row> inputs2, - List<Row> inputs3, - String[] keys1, - String[] keys2, - String[] keys3, - Schema expectedSchema) { - List<Row> joined = Lists.newArrayList(); - for (Row row1 : inputs1) { - for (Row row2 : inputs2) { - for (Row row3 : inputs3) { - List key1 = Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList()); - List key2 = Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList()); - List key3 = Arrays.stream(keys3).map(row3::getValue).collect(Collectors.toList()); - if (key1.equals(key2) && key2.equals(key3)) { - joined.add(Row.withSchema(expectedSchema).addValues(row1, row2, row3).build()); - } - } - } - } - return joined; - } - @Test @Category(NeedsRunner.class) public void testInnerJoin() { @@ -477,7 +451,7 @@ public class CoGroupTest { assertEquals(expectedSchema, joined.getSchema()); List<Row> expectedJoinedRows = - innerJoin( + JoinTestUtils.innerJoin( pc1Rows, pc2Rows, pc3Rows, @@ -545,14 +519,14 @@ public class CoGroupTest { PCollectionTuple.of("pc1", pc1, "pc2", pc2, "pc3", pc3) .apply( "CoGroup", - CoGroup.join("pc1", By.fieldNames("user", "country").withOuterJoinParticipation()) - .join("pc2", By.fieldNames("user2", "country2").withOuterJoinParticipation()) - .join("pc3", By.fieldNames("user3", "country3").withOuterJoinParticipation()) + CoGroup.join("pc1", By.fieldNames("user", "country").withOptionalParticipation()) + .join("pc2", By.fieldNames("user2", "country2").withOptionalParticipation()) + .join("pc3", By.fieldNames("user3", 
"country3").withOptionalParticipation()) .crossProductJoin()); assertEquals(expectedSchema, joined.getSchema()); List<Row> expectedJoinedRows = - innerJoin( + JoinTestUtils.innerJoin( pc1Rows, pc2Rows, pc3Rows, @@ -633,13 +607,13 @@ public class CoGroupTest { .apply( "CoGroup", CoGroup.join("pc1", By.fieldNames("user", "country")) - .join("pc2", By.fieldNames("user2", "country2").withOuterJoinParticipation()) + .join("pc2", By.fieldNames("user2", "country2").withOptionalParticipation()) .join("pc3", By.fieldNames("user3", "country3")) .crossProductJoin()); assertEquals(expectedSchema, joined.getSchema()); List<Row> expectedJoinedRows = - innerJoin( + JoinTestUtils.innerJoin( pc1Rows, pc2Rows, pc3Rows, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTest.java new file mode 100644 index 0000000..0b2631f --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTest.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.schemas.transforms; + +import static junit.framework.TestCase.assertEquals; +import static org.apache.beam.sdk.schemas.transforms.JoinTestUtils.innerJoin; + +import java.util.List; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.transforms.Join.FieldsEqual; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +/** Tests for {@link org.apache.beam.sdk.schemas.transforms.Join}. */ +public class JoinTest { + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + private static final Schema CG_SCHEMA_1 = + Schema.builder() + .addStringField("user") + .addInt32Field("count") + .addStringField("country") + .build(); + private static final Schema CG_SCHEMA_2 = + Schema.builder() + .addStringField("user2") + .addInt32Field("count2") + .addStringField("country2") + .build(); + + @Test + @Category(NeedsRunner.class) + public void testInnerJoinSameKeys() { + List<Row> pc1Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, 
"ar").build()); + List<Row> pc2Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build()); + + PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1); + PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1); + + Schema expectedSchema = + Schema.builder() + .addRowField(Join.LHS_TAG, CG_SCHEMA_1) + .addRowField(Join.RHS_TAG, CG_SCHEMA_1) + .build(); + + PCollection<Row> joined = pc1.apply(Join.<Row, Row>innerJoin(pc2).using("user", "country")); + + assertEquals(expectedSchema, joined.getSchema()); + + List<Row> expectedJoinedRows = + innerJoin( + pc1Rows, + pc2Rows, + new String[] {"user", "country"}, + new String[] {"user", "country"}, + expectedSchema); + + PAssert.that(joined).containsInAnyOrder(expectedJoinedRows); + pipeline.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testInnerJoinDifferentKeys() { + List<Row> pc1Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, 
"ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build()); + List<Row> pc2Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user4", 8, "ar").build()); + + PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1); + PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2); + + Schema expectedSchema = + Schema.builder() + .addRowField(Join.LHS_TAG, CG_SCHEMA_1) + .addRowField(Join.RHS_TAG, CG_SCHEMA_2) + .build(); + + PCollection<Row> joined = + pc1.apply( + Join.<Row, Row>innerJoin(pc2) + .on(FieldsEqual.left("user", "country").right("user2", "country2"))); + + assertEquals(expectedSchema, joined.getSchema()); + + List<Row> expectedJoinedRows = + innerJoin( + pc1Rows, + pc2Rows, + new String[] {"user", "country"}, + new String[] {"user2", "country2"}, + expectedSchema); + + PAssert.that(joined).containsInAnyOrder(expectedJoinedRows); + pipeline.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testOuterJoinDifferentKeys() { + List<Row> pc1Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(), + 
Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build()); + List<Row> pc2Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_2).addValues("user1", 9, "us").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user1", 10, "us").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user1", 11, "il").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user1", 12, "il").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 13, "fr").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 14, "fr").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 15, "ar").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user2", 16, "ar").build(), + Row.withSchema(CG_SCHEMA_2).addValues("user4", 8, "ar").build()); + + PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1); + PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_2); + + Schema expectedSchema = + Schema.builder() + .addNullableField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1)) + .addNullableField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_2)) + .build(); + + PCollection<Row> joined = + pc1.apply( + Join.<Row, Row>fullOuterJoin(pc2) + .on(FieldsEqual.left("user", "country").right("user2", "country2"))); + + assertEquals(expectedSchema, joined.getSchema()); + + List<Row> expectedJoinedRows = + innerJoin( + pc1Rows, + pc2Rows, + new String[] {"user", "country"}, + new String[] {"user2", "country2"}, + expectedSchema); + expectedJoinedRows.add( + Row.withSchema(expectedSchema) + .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build(), null) + .build()); + expectedJoinedRows.add( + Row.withSchema(expectedSchema) + .addValues(null, 
Row.withSchema(CG_SCHEMA_2).addValues("user4", 8, "ar").build()) + .build()); + + PAssert.that(joined).containsInAnyOrder(expectedJoinedRows); + pipeline.run(); + } + + /** + * Full outer join keyed by {@code using("user", "country")}: both inputs use CG_SCHEMA_1, so the + * same field names serve as keys on each side. Rows with no partner (user3 on the left, user4 on + * the right) are emitted paired with null, and both output fields are nullable. + */ + @Test + @Category(NeedsRunner.class) + public void testOuterJoinSameKeys() { + List<Row> pc1Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build()); + List<Row> pc2Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build()); + + PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1); + PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1); + + Schema expectedSchema = + Schema.builder() + .addNullableField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1)) + .addNullableField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_1)) + .build(); + + PCollection<Row> joined = pc1.apply(Join.<Row,
Row>fullOuterJoin(pc2).using("user", "country")); + + assertEquals(expectedSchema, joined.getSchema()); + + List<Row> expectedJoinedRows = + innerJoin( + pc1Rows, + pc2Rows, + new String[] {"user", "country"}, + new String[] {"user", "country"}, + expectedSchema); + expectedJoinedRows.add( + Row.withSchema(expectedSchema) + .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build(), null) + .build()); + expectedJoinedRows.add( + Row.withSchema(expectedSchema) + .addValues(null, Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build()) + .build()); + + PAssert.that(joined).containsInAnyOrder(expectedJoinedRows); + pipeline.run(); + } + + /** + * Left outer join keyed by {@code using("user", "country")} on two CG_SCHEMA_1 inputs. The + * unmatched left row (user3) is emitted paired with null, the unmatched right row (user4) is + * dropped, and only the right-hand output field is nullable. + */ + @Test + @Category(NeedsRunner.class) + public void testLeftOuterJoinSameKeys() { + List<Row> pc1Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build()); + List<Row> pc2Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(), +
Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build()); + + PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1); + PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1); + + Schema expectedSchema = + Schema.builder() + .addField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1)) + .addNullableField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_1)) + .build(); + + PCollection<Row> joined = pc1.apply(Join.<Row, Row>leftOuterJoin(pc2).using("user", "country")); + + assertEquals(expectedSchema, joined.getSchema()); + + List<Row> expectedJoinedRows = + innerJoin( + pc1Rows, + pc2Rows, + new String[] {"user", "country"}, + new String[] {"user", "country"}, + expectedSchema); + expectedJoinedRows.add( + Row.withSchema(expectedSchema) + .addValues(Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build(), null) + .build()); + + PAssert.that(joined).containsInAnyOrder(expectedJoinedRows); + pipeline.run(); + } + + /** + * Right outer join keyed by {@code using("user", "country")} on two CG_SCHEMA_1 inputs. The + * unmatched right row (user4) is emitted paired with null, the unmatched left row (user3) is + * dropped, and only the left-hand output field is nullable. + */ + @Test + @Category(NeedsRunner.class) + public void testRightOuterJoinSameKeys() { + List<Row> pc1Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 1, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 2, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 3, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 4, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 5, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 6, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 7, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 8, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user3", 8, "ar").build()); + List<Row> pc2Rows = + Lists.newArrayList( + Row.withSchema(CG_SCHEMA_1).addValues("user1", 9, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 10, "us").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user1", 11, "il").build(), +
Row.withSchema(CG_SCHEMA_1).addValues("user1", 12, "il").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 13, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 14, "fr").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 15, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user2", 16, "ar").build(), + Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build()); + + PCollection<Row> pc1 = pipeline.apply("Create1", Create.of(pc1Rows)).setRowSchema(CG_SCHEMA_1); + PCollection<Row> pc2 = pipeline.apply("Create2", Create.of(pc2Rows)).setRowSchema(CG_SCHEMA_1); + + Schema expectedSchema = + Schema.builder() + .addNullableField(Join.LHS_TAG, Schema.FieldType.row(CG_SCHEMA_1)) + .addField(Join.RHS_TAG, Schema.FieldType.row(CG_SCHEMA_1)) + .build(); + + PCollection<Row> joined = + pc1.apply(Join.<Row, Row>rightOuterJoin(pc2).using("user", "country")); + + assertEquals(expectedSchema, joined.getSchema()); + + List<Row> expectedJoinedRows = + innerJoin( + pc1Rows, + pc2Rows, + new String[] {"user", "country"}, + new String[] {"user", "country"}, + expectedSchema); + expectedJoinedRows.add( + Row.withSchema(expectedSchema) + .addValues(null, Row.withSchema(CG_SCHEMA_1).addValues("user4", 8, "ar").build()) + .build()); + + PAssert.that(joined).containsInAnyOrder(expectedJoinedRows); + pipeline.run(); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTestUtils.java new file mode 100644 index 0000000..2ef6370 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/JoinTestUtils.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas.transforms; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists; + +class JoinTestUtils { + static List<Row> innerJoin( + List<Row> inputs1, List<Row> inputs2, String[] keys1, String[] keys2, Schema expectedSchema) { + List<Row> joined = Lists.newArrayList(); + for (Row row1 : inputs1) { + for (Row row2 : inputs2) { + List key1 = Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList()); + List key2 = Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList()); + if (key1.equals(key2)) { + joined.add(Row.withSchema(expectedSchema).addValues(row1, row2).build()); + } + } + } + return joined; + } + + static List<Row> innerJoin( + List<Row> inputs1, + List<Row> inputs2, + List<Row> inputs3, + String[] keys1, + String[] keys2, + String[] keys3, + Schema expectedSchema) { + List<Row> joined = Lists.newArrayList(); + for (Row row1 : inputs1) { + for (Row row2 : inputs2) { + for (Row row3 : inputs3) { + List key1 = Arrays.stream(keys1).map(row1::getValue).collect(Collectors.toList()); + List key2 = Arrays.stream(keys2).map(row2::getValue).collect(Collectors.toList()); + List key3 = 
Arrays.stream(keys3).map(row3::getValue).collect(Collectors.toList()); + if (key1.equals(key2) && key2.equals(key3)) { + joined.add(Row.withSchema(expectedSchema).addValues(row1, row2, row3).build()); + } + } + } + } + return joined; + } +}