diff --git a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java index d6df7a8d1c9e..7c2e8168f011 100644 --- a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java +++ b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java @@ -56,7 +56,8 @@ final Window<KV<KeyT, RightT>> rightWindow = (Window) operator.getWindow().get(); rightKeyed = rightKeyed.apply("window-right", rightWindow); } - return translate(operator, leftKeyed, rightKeyed) + + return translate(operator, left, leftKeyed, right, rightKeyed) .setTypeDescriptor( operator .getOutputType() @@ -66,6 +67,8 @@ abstract PCollection<KV<KeyT, OutputT>> translate( Join<LeftT, RightT, KeyT, OutputT> operator, - PCollection<KV<KeyT, LeftT>> left, - PCollection<KV<KeyT, RightT>> right); + PCollection<LeftT> left, + PCollection<KV<KeyT, LeftT>> leftKeyed, + PCollection<RightT> right, + PCollection<KV<KeyT, RightT>> rightKeyed); } diff --git a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java index 90870a13e4db..05d923cb8aca 100644 --- a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java +++ b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java @@ -17,11 +17,15 @@ */ package org.apache.beam.sdk.extensions.euphoria.core.translate; +import com.google.common.annotations.VisibleForTesting; +import 
com.google.common.collect.HashBasedTable; +import com.google.common.collect.Table; import java.util.Collections; import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.extensions.euphoria.core.client.accumulators.AccumulatorProvider; import org.apache.beam.sdk.extensions.euphoria.core.client.functional.BinaryFunctor; +import org.apache.beam.sdk.extensions.euphoria.core.client.functional.UnaryFunction; import org.apache.beam.sdk.extensions.euphoria.core.client.operator.Join; import org.apache.beam.sdk.extensions.euphoria.core.translate.collector.AdaptableCollector; import org.apache.beam.sdk.transforms.DoFn; @@ -35,22 +39,42 @@ * Translator for {@link org.apache.beam.sdk.extensions.euphoria.core.client.operator.RightJoin} and * {@link org.apache.beam.sdk.extensions.euphoria.core.client.operator.LeftJoin} when one side of * the join fits in memory so it can be distributed in hash map with the other side. + * + * <p>Note that when reusing the smaller join side in several broadcast hash joins, there are rules + * to follow to avoid sending the same data to executors repeatedly: + * + * <ul> + * <li>Input {@link PCollection} of the broadcast side has to be the same instance + * <li>Key extractor of the broadcast side has to be the same {@link UnaryFunction} instance + * </ul> */ public class BroadcastHashJoinTranslator<LeftT, RightT, KeyT, OutputT> extends AbstractJoinTranslator<LeftT, RightT, KeyT, OutputT> { + /** + * Used to prevent multiple views of the same input PCollection, and therefore multiple broadcasts + * of the same data. 
+ */ + @VisibleForTesting + final Table<PCollection<?>, UnaryFunction<?, KeyT>, PCollectionView<?>> pViews = + HashBasedTable.create(); + @Override PCollection<KV<KeyT, OutputT>> translate( Join<LeftT, RightT, KeyT, OutputT> operator, - PCollection<KV<KeyT, LeftT>> left, - PCollection<KV<KeyT, RightT>> right) { + PCollection<LeftT> left, + PCollection<KV<KeyT, LeftT>> leftKeyed, + PCollection<RightT> right, + PCollection<KV<KeyT, RightT>> rightKeyed) { + final AccumulatorProvider accumulators = new LazyAccumulatorProvider(AccumulatorProvider.of(left.getPipeline())); + switch (operator.getType()) { case LEFT: final PCollectionView<Map<KeyT, Iterable<RightT>>> broadcastRight = - right.apply(View.asMultimap()); - return left.apply( + computeViewAsMultimapIfAbsent(right, operator.getRightKeyExtractor(), rightKeyed); + return leftKeyed.apply( ParDo.of( new BroadcastHashLeftJoinFn<>( broadcastRight, @@ -60,8 +84,8 @@ .withSideInputs(broadcastRight)); case RIGHT: final PCollectionView<Map<KeyT, Iterable<LeftT>>> broadcastLeft = - left.apply(View.asMultimap()); - return right.apply( + computeViewAsMultimapIfAbsent(left, operator.getLeftKeyExtractor(), leftKeyed); + return rightKeyed.apply( ParDo.of( new BroadcastHashRightJoinFn<>( broadcastLeft, @@ -78,6 +102,31 @@ } } + /** + * Creates a new {@link PCollectionView} of the given {@code pCollectionToView} iff there is no {@link + * PCollectionView} already associated with the given {@code pcollection} and {@code keyExtractor}. 
+ * + * @param pCollectionToView the {@link PCollection} the view will be created from by applying {@link + * View#asMultimap()} + * @param <V> value type + * @return the current (already existing or computed) value associated with the specified key + */ + private <V> PCollectionView<Map<KeyT, Iterable<V>>> computeViewAsMultimapIfAbsent( + PCollection<V> pcollection, + UnaryFunction<?, KeyT> keyExtractor, + final PCollection<KV<KeyT, V>> pCollectionToView) { + + PCollectionView<?> view = pViews.get(pcollection, keyExtractor); + if (view == null) { + view = pCollectionToView.apply(View.asMultimap()); + pViews.put(pcollection, keyExtractor, view); + } + + @SuppressWarnings("unchecked") + PCollectionView<Map<KeyT, Iterable<V>>> ret = (PCollectionView<Map<KeyT, Iterable<V>>>) view; + return ret; + } + static class BroadcastHashRightJoinFn<K, LeftT, RightT, OutputT> extends DoFn<KV<K, RightT>, KV<K, OutputT>> { diff --git a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java index 09a22dfa8e62..e967888ff13c 100644 --- a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java +++ b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java @@ -244,16 +244,18 @@ public String getFnName() { @Override PCollection<KV<KeyT, OutputT>> translate( Join<LeftT, RightT, KeyT, OutputT> operator, - PCollection<KV<KeyT, LeftT>> left, - PCollection<KV<KeyT, RightT>> right) { + PCollection<LeftT> left, + PCollection<KV<KeyT, LeftT>> leftKeyed, + PCollection<RightT> reight, + PCollection<KV<KeyT, RightT>> rightKeyed) { final AccumulatorProvider accumulators = - new LazyAccumulatorProvider(AccumulatorProvider.of(left.getPipeline())); + new 
LazyAccumulatorProvider(AccumulatorProvider.of(leftKeyed.getPipeline())); final TupleTag<LeftT> leftTag = new TupleTag<>(); final TupleTag<RightT> rightTag = new TupleTag<>(); final JoinFn<LeftT, RightT, KeyT, OutputT> joinFn = getJoinFn(operator, leftTag, rightTag, accumulators); - return KeyedPCollectionTuple.of(leftTag, left) - .and(rightTag, right) + return KeyedPCollectionTuple.of(leftTag, leftKeyed) + .and(rightTag, rightKeyed) .apply("co-group-by-key", CoGroupByKey.create()) .apply(joinFn.getFnName(), ParDo.of(joinFn)); } diff --git a/sdks/java/extensions/euphoria/src/test/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslatorTest.java b/sdks/java/extensions/euphoria/src/test/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslatorTest.java new file mode 100644 index 000000000000..daff0621a462 --- /dev/null +++ b/sdks/java/extensions/euphoria/src/test/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslatorTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.euphoria.core.translate; + +import org.apache.beam.sdk.extensions.euphoria.core.client.functional.UnaryFunction; +import org.apache.beam.sdk.extensions.euphoria.core.client.operator.Join; +import org.apache.beam.sdk.extensions.euphoria.core.client.operator.LeftJoin; +import org.apache.beam.sdk.extensions.euphoria.core.client.operator.RightJoin; +import org.apache.beam.sdk.extensions.euphoria.core.translate.provider.CompositeProvider; +import org.apache.beam.sdk.extensions.euphoria.core.translate.provider.GenericTranslatorProvider; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +/** Unit tests of {@link BroadcastHashJoinTranslator}. */ +public class BroadcastHashJoinTranslatorTest { + + @Rule public TestPipeline p = TestPipeline.create(); + + @Test + public void twoUsesOneViewTest() { + + BroadcastHashJoinTranslator<?, ?, ?, ?> translatorUnderTest = + new BroadcastHashJoinTranslator<>(); + + EuphoriaOptions options = p.getOptions().as(EuphoriaOptions.class); + // Every join in this test will be translated as Broadcast + options.setTranslatorProvider( + CompositeProvider.of( + GenericTranslatorProvider.newBuilder() + .register(Join.class, (op) -> true, translatorUnderTest) + .build(), + GenericTranslatorProvider.createWithDefaultTranslators())); + + // create input to be broadcast + PCollection<KV<Integer, String>> lengthStrings = + p.apply("names", + Create.of(KV.of(1, "one"), KV.of(2, "two"), KV.of(3, "three"))) + .setTypeDescriptor( + TypeDescriptors.kvs(TypeDescriptors.integers(), TypeDescriptors.strings())); + + UnaryFunction<KV<Integer, String>, Integer> sharedKeyExtractor = KV::getKey; + + // other datasets to 
be joined with + PCollection<String> letters = + p.apply("letters", Create.of("a", "b", "c", "d")).setTypeDescriptor(TypeDescriptors.strings()); + PCollection<String> acronyms = + p.apply("acronyms", Create.of("B2K", "DIY", "FKA", "EOBD")).setTypeDescriptor(TypeDescriptors.strings()); + + PCollection<KV<Integer, String>> lettersJoined = + LeftJoin.named("join-letters-with-lengths") + .of(letters, lengthStrings) + .by(String::length, sharedKeyExtractor, TypeDescriptors.integers()) + .using( + (letter, maybeLength, ctx) -> + ctx.collect(letter + "-" + maybeLength.orElse(KV.of(-1, "null")).getValue()), + TypeDescriptors.strings()) + .output(); + + PCollection<KV<Integer, String>> acronymsJoined = + RightJoin.named("join-acronyms-with-lengths") + .of(lengthStrings, acronyms) + .by(sharedKeyExtractor, String::length, TypeDescriptors.integers()) + .using( + (maybeLength, acronym, ctx) -> + ctx.collect(maybeLength.orElse(KV.of(-1, "null")).getValue() + "-" + acronym), + TypeDescriptors.strings()) + .output(); + + + PAssert.that(lettersJoined) + .containsInAnyOrder( + KV.of(1, "a-one"), KV.of(1, "b-one"), KV.of(1, "c-one"), KV.of(1, "d-one")); + + PAssert.that(acronymsJoined) + .containsInAnyOrder( + KV.of(3, "three-B2K"), + KV.of(3, "three-DIY"), + KV.of(3, "three-FKA"), + KV.of(4, "null-EOBD")); + + p.run(); + Assert.assertEquals(1, translatorUnderTest.pViews.size()); + } +}
With regards, Apache Git Services