diff --git 
a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java
 
b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java
index d6df7a8d1c9e..7c2e8168f011 100644
--- 
a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java
+++ 
b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/AbstractJoinTranslator.java
@@ -56,7 +56,8 @@
       final Window<KV<KeyT, RightT>> rightWindow = (Window) 
operator.getWindow().get();
       rightKeyed = rightKeyed.apply("window-right", rightWindow);
     }
-    return translate(operator, leftKeyed, rightKeyed)
+
+    return translate(operator, left, leftKeyed, right, rightKeyed)
         .setTypeDescriptor(
             operator
                 .getOutputType()
@@ -66,6 +67,8 @@
 
   abstract PCollection<KV<KeyT, OutputT>> translate(
       Join<LeftT, RightT, KeyT, OutputT> operator,
-      PCollection<KV<KeyT, LeftT>> left,
-      PCollection<KV<KeyT, RightT>> right);
+      PCollection<LeftT> left,
+      PCollection<KV<KeyT, LeftT>> leftKeyed,
+      PCollection<RightT> right,
+      PCollection<KV<KeyT, RightT>> rightKeyed);
 }
diff --git 
a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java
 
b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java
index 90870a13e4db..05d923cb8aca 100644
--- 
a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java
+++ 
b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslator.java
@@ -17,11 +17,15 @@
  */
 package org.apache.beam.sdk.extensions.euphoria.core.translate;
 
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.HashBasedTable;
+import com.google.common.collect.Table;
 import java.util.Collections;
 import java.util.Map;
 import javax.annotation.Nullable;
 import 
org.apache.beam.sdk.extensions.euphoria.core.client.accumulators.AccumulatorProvider;
 import 
org.apache.beam.sdk.extensions.euphoria.core.client.functional.BinaryFunctor;
+import 
org.apache.beam.sdk.extensions.euphoria.core.client.functional.UnaryFunction;
 import org.apache.beam.sdk.extensions.euphoria.core.client.operator.Join;
 import 
org.apache.beam.sdk.extensions.euphoria.core.translate.collector.AdaptableCollector;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -35,22 +39,42 @@
  * Translator for {@link 
org.apache.beam.sdk.extensions.euphoria.core.client.operator.RightJoin} and
  * {@link 
org.apache.beam.sdk.extensions.euphoria.core.client.operator.LeftJoin} when one 
side of
  * the join fits in memory so it can be distributed in hash map with the other 
side.
+ *
+ * <p>Note that when the smaller join side is reused by several broadcast hash joins, the rules
+ * below have to be followed to avoid sending the same data to executors repeatedly:
+ *
+ * <ul>
+ *   <li>Input {@link PCollection} of broadcasted side has to be the same 
instance
+ *   <li>Key extractor of broadcasted side has to be the same {@link 
UnaryFunction} instance
+ * </ul>
  */
 public class BroadcastHashJoinTranslator<LeftT, RightT, KeyT, OutputT>
     extends AbstractJoinTranslator<LeftT, RightT, KeyT, OutputT> {
 
+  /**
+   * Caches the view created for each (input PCollection, key extractor) pair. This prevents
+   * creating several views of the same input, and therefore broadcasting the same data repeatedly.
+   */
+  @VisibleForTesting
+  final Table<PCollection<?>, UnaryFunction<?, KeyT>, PCollectionView<?>> 
pViews =
+      HashBasedTable.create();
+
   @Override
   PCollection<KV<KeyT, OutputT>> translate(
       Join<LeftT, RightT, KeyT, OutputT> operator,
-      PCollection<KV<KeyT, LeftT>> left,
-      PCollection<KV<KeyT, RightT>> right) {
+      PCollection<LeftT> left,
+      PCollection<KV<KeyT, LeftT>> leftKeyed,
+      PCollection<RightT> right,
+      PCollection<KV<KeyT, RightT>> rightKeyed) {
+
     final AccumulatorProvider accumulators =
         new 
LazyAccumulatorProvider(AccumulatorProvider.of(left.getPipeline()));
+
     switch (operator.getType()) {
       case LEFT:
         final PCollectionView<Map<KeyT, Iterable<RightT>>> broadcastRight =
-            right.apply(View.asMultimap());
-        return left.apply(
+            computeViewAsMultimapIfAbsent(right, 
operator.getRightKeyExtractor(), rightKeyed);
+        return leftKeyed.apply(
             ParDo.of(
                     new BroadcastHashLeftJoinFn<>(
                         broadcastRight,
@@ -60,8 +84,8 @@
                 .withSideInputs(broadcastRight));
       case RIGHT:
         final PCollectionView<Map<KeyT, Iterable<LeftT>>> broadcastLeft =
-            left.apply(View.asMultimap());
-        return right.apply(
+            computeViewAsMultimapIfAbsent(left, 
operator.getLeftKeyExtractor(), leftKeyed);
+        return rightKeyed.apply(
             ParDo.of(
                     new BroadcastHashRightJoinFn<>(
                         broadcastLeft,
@@ -78,6 +102,31 @@
     }
   }
 
+  /**
+   * Returns the {@link PCollectionView} cached for {@code pcollection} and {@code keyExtractor},
+   * creating and caching it from {@code pCollectionToView} when no such view exists yet.
+   *
+   * @param pcollection un-keyed input of the broadcast side; row key of the view cache
+   * @param keyExtractor key extractor of the broadcast side; column key of the view cache
+   * @param pCollectionToView keyed input the view is created from via {@link View#asMultimap()}
+   * @param <V> type of the values of the broadcast side
+   * @return the already existing or newly created view
+   */
+  private <V> PCollectionView<Map<KeyT, Iterable<V>>> 
computeViewAsMultimapIfAbsent(
+      PCollection<V> pcollection,
+      UnaryFunction<?, KeyT> keyExtractor,
+      final PCollection<KV<KeyT, V>> pCollectionToView) {
+    // NOTE(review): the cache key ignores any windowing of pCollectionToView - confirm it is safe.
+    PCollectionView<?> view = pViews.get(pcollection, keyExtractor);
+    if (view == null) {
+      view = pCollectionToView.apply(View.asMultimap());
+      pViews.put(pcollection, keyExtractor, view);
+    }
+    // Unchecked cast: presumably the cached view was created for the same value type V - verify.
+    @SuppressWarnings("unchecked")
+    PCollectionView<Map<KeyT, Iterable<V>>> ret = (PCollectionView<Map<KeyT, 
Iterable<V>>>) view;
+    return ret;
+  }
+
   static class BroadcastHashRightJoinFn<K, LeftT, RightT, OutputT>
       extends DoFn<KV<K, RightT>, KV<K, OutputT>> {
 
diff --git 
a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java
 
b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java
index 09a22dfa8e62..e967888ff13c 100644
--- 
a/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java
+++ 
b/sdks/java/extensions/euphoria/src/main/java/org/apache/beam/sdk/extensions/euphoria/core/translate/JoinTranslator.java
@@ -244,16 +244,19 @@ public String getFnName() {
   @Override
   PCollection<KV<KeyT, OutputT>> translate(
       Join<LeftT, RightT, KeyT, OutputT> operator,
-      PCollection<KV<KeyT, LeftT>> left,
-      PCollection<KV<KeyT, RightT>> right) {
+      PCollection<LeftT> left,
+      PCollection<KV<KeyT, LeftT>> leftKeyed,
+      PCollection<RightT> right,
+      PCollection<KV<KeyT, RightT>> rightKeyed) {
+    // Only the keyed inputs are joined (co-group-by-key); the un-keyed inputs are unused here.
     final AccumulatorProvider accumulators =
-        new LazyAccumulatorProvider(AccumulatorProvider.of(left.getPipeline()));
+        new LazyAccumulatorProvider(AccumulatorProvider.of(leftKeyed.getPipeline()));
     final TupleTag<LeftT> leftTag = new TupleTag<>();
     final TupleTag<RightT> rightTag = new TupleTag<>();
     final JoinFn<LeftT, RightT, KeyT, OutputT> joinFn =
         getJoinFn(operator, leftTag, rightTag, accumulators);
-    return KeyedPCollectionTuple.of(leftTag, left)
-        .and(rightTag, right)
+    return KeyedPCollectionTuple.of(leftTag, leftKeyed)
+        .and(rightTag, rightKeyed)
         .apply("co-group-by-key", CoGroupByKey.create())
         .apply(joinFn.getFnName(), ParDo.of(joinFn));
   }
diff --git 
a/sdks/java/extensions/euphoria/src/test/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslatorTest.java
 
b/sdks/java/extensions/euphoria/src/test/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslatorTest.java
new file mode 100644
index 000000000000..daff0621a462
--- /dev/null
+++ 
b/sdks/java/extensions/euphoria/src/test/java/org/apache/beam/sdk/extensions/euphoria/core/translate/BroadcastHashJoinTranslatorTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.euphoria.core.translate;
+
+import 
org.apache.beam.sdk.extensions.euphoria.core.client.functional.UnaryFunction;
+import org.apache.beam.sdk.extensions.euphoria.core.client.operator.Join;
+import org.apache.beam.sdk.extensions.euphoria.core.client.operator.LeftJoin;
+import org.apache.beam.sdk.extensions.euphoria.core.client.operator.RightJoin;
+import 
org.apache.beam.sdk.extensions.euphoria.core.translate.provider.CompositeProvider;
+import 
org.apache.beam.sdk.extensions.euphoria.core.translate.provider.GenericTranslatorProvider;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+
+/** Unit tests of {@link BroadcastHashJoinTranslator}. */
+public class BroadcastHashJoinTranslatorTest {
+
+  @Rule public TestPipeline p = TestPipeline.create();
+
+  @Test
+  public void twoUsesOneViewTest() {
+    // Two joins broadcast the same input with the same key extractor; only one view is expected.
+    BroadcastHashJoinTranslator<?, ?, ?, ?> translatorUnderTest =
+        new BroadcastHashJoinTranslator<>();
+
+    EuphoriaOptions options = p.getOptions().as(EuphoriaOptions.class);
+    // Every join in this test will be translated by the broadcast hash join translator under test
+    options.setTranslatorProvider(
+        CompositeProvider.of(
+            GenericTranslatorProvider.newBuilder()
+                .register(Join.class, (op) -> true, translatorUnderTest)
+                .build(),
+            GenericTranslatorProvider.createWithDefaultTranslators()));
+
+    // input shared by both joins; it is the broadcast side of each of them
+    PCollection<KV<Integer, String>> lengthStrings =
+        p.apply("names",
+            Create.of(KV.of(1, "one"), KV.of(2, "two"), KV.of(3, "three")))
+                .setTypeDescriptor(
+                    TypeDescriptors.kvs(TypeDescriptors.integers(), 
TypeDescriptors.strings()));
+
+    UnaryFunction<KV<Integer, String>, Integer> sharedKeyExtractor = 
KV::getKey;
+
+    // other datasets to be joined with
+    PCollection<String> letters =
+        p.apply("letters", Create.of("a", "b", "c", 
"d")).setTypeDescriptor(TypeDescriptors.strings());
+    PCollection<String> acronyms =
+        p.apply("acronyms", Create.of("B2K", "DIY", "FKA", 
"EOBD")).setTypeDescriptor(TypeDescriptors.strings());
+
+    PCollection<KV<Integer, String>> lettersJoined =
+        LeftJoin.named("join-letters-with-lengths")
+            .of(letters, lengthStrings)
+            .by(String::length, sharedKeyExtractor, TypeDescriptors.integers())
+            .using(
+                (letter, maybeLength, ctx) ->
+                    ctx.collect(letter + "-" + maybeLength.orElse(KV.of(-1, 
"null")).getValue()),
+                TypeDescriptors.strings())
+            .output();
+
+    PCollection<KV<Integer, String>> acronymsJoined =
+        RightJoin.named("join-acronyms-with-lengths")
+            .of(lengthStrings, acronyms)
+            .by(sharedKeyExtractor, String::length, TypeDescriptors.integers())
+            .using(
+                (maybeLength, acronym, ctx) ->
+                    ctx.collect(maybeLength.orElse(KV.of(-1, 
"null")).getValue() + "-" + acronym),
+                TypeDescriptors.strings())
+            .output();
+
+    // every letter has length 1, so each one is joined with KV.of(1, "one")
+    PAssert.that(lettersJoined)
+        .containsInAnyOrder(
+            KV.of(1, "a-one"), KV.of(1, "b-one"), KV.of(1, "c-one"), KV.of(1, 
"d-one"));
+    // "EOBD" has length 4, which has no match in lengthStrings, hence the "null" placeholder
+    PAssert.that(acronymsJoined)
+        .containsInAnyOrder(
+            KV.of(3, "three-B2K"),
+            KV.of(3, "three-DIY"),
+            KV.of(3, "three-FKA"),
+            KV.of(4, "null-EOBD"));
+
+    p.run();
+    Assert.assertEquals(1, translatorUnderTest.pViews.size());
+  }
+}


With regards,
Apache Git Services

Reply via email to