This is an automated email from the ASF dual-hosted git repository.

xinyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b4aaa0  [BEAM-12823] TestStream Support in Samza Portable Runner (#15421)
2b4aaa0 is described below

commit 2b4aaa0c2101163626ab6aa223d35b58cf4bc5b6
Author: Ke Wu <[email protected]>
AuthorDate: Thu Sep 2 16:18:11 2021 -0700

    [BEAM-12823] TestStream Support in Samza Portable Runner (#15421)
---
 .../core/construction/TestStreamTranslation.java   | 11 ++-
 runners/samza/job-server/build.gradle              | 18 ++++-
 .../SamzaPortablePipelineTranslator.java           |  1 +
 .../translation/SamzaTestStreamSystemFactory.java  | 23 +++---
 .../translation/SamzaTestStreamTranslator.java     | 90 +++++++++++++++++-----
 5 files changed, 106 insertions(+), 37 deletions(-)

diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java
index 82d8810..7060ca0 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java
@@ -49,13 +49,18 @@ import org.joda.time.Instant;
 })
 public class TestStreamTranslation {
 
-  public static TestStream<?> testStreamFromProtoPayload(
+  public static <T> TestStream<T> testStreamFromProtoPayload(
       RunnerApi.TestStreamPayload testStreamPayload, RehydratedComponents 
components)
       throws IOException {
 
-    Coder<Object> coder = (Coder<Object>) 
components.getCoder(testStreamPayload.getCoderId());
+    Coder<T> coder = (Coder<T>) 
components.getCoder(testStreamPayload.getCoderId());
 
-    List<TestStream.Event<Object>> events = new ArrayList<>();
+    return testStreamFromProtoPayload(testStreamPayload, coder);
+  }
+
+  public static <T> TestStream<T> testStreamFromProtoPayload(
+      RunnerApi.TestStreamPayload testStreamPayload, Coder<T> coder) throws 
IOException {
+    List<TestStream.Event<T>> events = new ArrayList<>();
 
     for (RunnerApi.TestStreamPayload.Event event : 
testStreamPayload.getEventsList()) {
       events.add(eventFromProto(event, coder));
diff --git a/runners/samza/job-server/build.gradle b/runners/samza/job-server/build.gradle
index 4c03123..3fe2982 100644
--- a/runners/samza/job-server/build.gradle
+++ b/runners/samza/job-server/build.gradle
@@ -91,7 +91,7 @@ createPortableValidatesRunnerTask(
             excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer'
             excludeCategories 
'org.apache.beam.sdk.testing.UsesOrderedListState'
             excludeCategories 
'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
-            excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
+            excludeCategories 
'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
             excludeCategories 
'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
             excludeCategories 
'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
             excludeCategories 'org.apache.beam.sdk.testing.UsesLoopingTimer'
@@ -102,6 +102,22 @@ createPortableValidatesRunnerTask(
             excludeTestsMatching 
"org.apache.beam.sdk.transforms.FlattenTest.testEmptyFlattenAsSideInput"
             excludeTestsMatching 
"org.apache.beam.sdk.transforms.FlattenTest.testFlattenPCollectionsEmptyThenParDo"
             excludeTestsMatching 
"org.apache.beam.sdk.transforms.FlattenTest.testFlattenPCollectionsEmpty"
+            excludeTestsMatching 
'org.apache.beam.sdk.transforms.ViewTest.testWindowedSideInputNotPresent'
+            excludeTestsMatching 
'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerAlignUnbounded'
+            excludeTestsMatching 
'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerUnbounded'
+            excludeTestsMatching 
'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerAlignAfterGcTimeUnbounded'
+            // TODO(BEAM-10025)
+            excludeTestsMatching 
'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testOutputTimestampDefaultUnbounded'
+            // TODO(BEAM-11479)
+            excludeTestsMatching 
'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testOutputTimestamp'
+            // TODO(BEAM-12035)
+            excludeTestsMatching 
'org.apache.beam.sdk.testing.TestStreamTest.testFirstElementLate'
+            // TODO(BEAM-12036)
+            excludeTestsMatching 
'org.apache.beam.sdk.testing.TestStreamTest.testLateDataAccumulating'
+            // TODO(BEAM-12821)
+            excludeTestsMatching 
'org.apache.beam.sdk.testing.TestStreamTest.testMultiStage'
+            // TODO(BEAM-12822)
+            excludeTestsMatching 
'org.apache.beam.sdk.testing.TestStreamTest.testMultipleStreams'
         }
 )
 
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
index 407cc5d..9158cd4 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPortablePipelineTranslator.java
@@ -106,6 +106,7 @@ public class SamzaPortablePipelineTranslator {
           .put(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN, new 
GroupByKeyTranslator<>())
           .put(PTransformTranslation.FLATTEN_TRANSFORM_URN, new 
FlattenPCollectionsTranslator<>())
           .put(PTransformTranslation.IMPULSE_TRANSFORM_URN, new 
ImpulseTranslator())
+          .put(PTransformTranslation.TEST_STREAM_TRANSFORM_URN, new 
SamzaTestStreamTranslator<>())
           .put(ExecutableStage.URN, new ParDoBoundMultiTranslator<>())
           .build();
     }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamSystemFactory.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamSystemFactory.java
index 96dc577..570be61 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamSystemFactory.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamSystemFactory.java
@@ -35,6 +35,7 @@ import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Immutabl
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.SystemConfig;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemAdmin;
@@ -55,9 +56,9 @@ import org.apache.samza.system.SystemStreamPartition;
 public class SamzaTestStreamSystemFactory implements SystemFactory {
   @Override
   public SystemConsumer getConsumer(String systemName, Config config, 
MetricsRegistry registry) {
-    final String streamPrefix = "systems." + systemName;
-    final Config scopedConfig = config.subset(streamPrefix + ".", true);
-    return new SmazaTestStreamSystemConsumer<>(getTestStream(scopedConfig));
+    final String streamPrefix = String.format(SystemConfig.SYSTEM_ID_PREFIX, 
systemName);
+    final Config scopedConfig = config.subset(streamPrefix, true);
+    return new SamzaTestStreamSystemConsumer<>(getTestStream(scopedConfig));
   }
 
   @Override
@@ -75,14 +76,13 @@ public class SamzaTestStreamSystemFactory implements 
SystemFactory {
     @SuppressWarnings("unchecked")
     final SerializableFunction<String, TestStream<T>> testStreamDecoder =
         Base64Serializer.deserializeUnchecked(
-            config.get("testStreamDecoder"), SerializableFunction.class);
-    final TestStream<T> testStream = 
testStreamDecoder.apply(config.get("encodedTestStream"));
-    return testStream;
+            config.get(SamzaTestStreamTranslator.TEST_STREAM_DECODER), 
SerializableFunction.class);
+    return 
testStreamDecoder.apply(config.get(SamzaTestStreamTranslator.ENCODED_TEST_STREAM));
   }
 
   private static final String DUMMY_OFFSET = "0";
 
-  /** System admin for SmazaTestStreamSystem. */
+  /** System admin for SamzaTestStreamSystem. */
   public static class SamzaTestStreamSystemAdmin implements SystemAdmin {
     @Override
     public Map<SystemStreamPartition, String> getOffsetsAfter(
@@ -115,11 +115,11 @@ public class SamzaTestStreamSystemFactory implements 
SystemFactory {
     }
   }
 
-  /** System consumer for SmazaTestStreamSystem. */
-  public static class SmazaTestStreamSystemConsumer<T> implements 
SystemConsumer {
+  /** System consumer for SamzaTestStreamSystem. */
+  public static class SamzaTestStreamSystemConsumer<T> implements 
SystemConsumer {
     TestStream<T> testStream;
 
-    public SmazaTestStreamSystemConsumer(TestStream<T> testStream) {
+    public SamzaTestStreamSystemConsumer(TestStream<T> testStream) {
       this.testStream = testStream;
     }
 
@@ -134,8 +134,7 @@ public class SamzaTestStreamSystemFactory implements 
SystemFactory {
 
     @Override
     public Map<SystemStreamPartition, List<IncomingMessageEnvelope>> poll(
-        Set<SystemStreamPartition> systemStreamPartitions, long timeout)
-        throws InterruptedException {
+        Set<SystemStreamPartition> systemStreamPartitions, long timeout) {
       SystemStreamPartition ssp = systemStreamPartitions.iterator().next();
       ArrayList<IncomingMessageEnvelope> messages = new ArrayList<>();
 
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamTranslator.java
index ef38a79..e50dc2c 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamTranslator.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaTestStreamTranslator.java
@@ -17,11 +17,16 @@
  */
 package org.apache.beam.runners.samza.translation;
 
+import java.io.IOException;
 import java.util.Map;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
+import org.apache.beam.runners.core.construction.TestStreamTranslation;
 import org.apache.beam.runners.core.construction.graph.PipelineNode;
 import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
 import org.apache.beam.runners.core.serialization.Base64Serializer;
 import org.apache.beam.runners.samza.runtime.OpMessage;
+import org.apache.beam.runners.samza.util.SamzaPipelineTranslatorUtils;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.runners.TransformHierarchy;
@@ -29,8 +34,8 @@ import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
-import org.apache.samza.SamzaException;
 import org.apache.samza.operators.KV;
 import org.apache.samza.serializers.KVSerde;
 import org.apache.samza.serializers.NoOpSerde;
@@ -40,11 +45,12 @@ import 
org.apache.samza.system.descriptors.GenericSystemDescriptor;
 
 /**
  * Translate {@link org.apache.beam.sdk.testing.TestStream} to a samza message 
stream produced by
- * {@link
- * 
org.apache.beam.runners.samza.translation.SamzaTestStreamSystemFactory.SmazaTestStreamSystemConsumer}.
+ * {@link SamzaTestStreamSystemFactory.SamzaTestStreamSystemConsumer}.
  */
 @SuppressWarnings({"rawtypes"})
 public class SamzaTestStreamTranslator<T> implements 
TransformTranslator<TestStream<T>> {
+  public static final String ENCODED_TEST_STREAM = "encodedTestStream";
+  public static final String TEST_STREAM_DECODER = "testStreamDecoder";
 
   @Override
   public void translate(
@@ -53,15 +59,13 @@ public class SamzaTestStreamTranslator<T> implements 
TransformTranslator<TestStr
     final String outputId = ctx.getIdForPValue(output);
     final Coder<T> valueCoder = testStream.getValueCoder();
     final TestStream.TestStreamCoder<T> testStreamCoder = 
TestStream.TestStreamCoder.of(valueCoder);
-    final GenericSystemDescriptor systemDescriptor =
-        new GenericSystemDescriptor(outputId, 
SamzaTestStreamSystemFactory.class.getName());
 
     // encode testStream as a string
     final String encodedTestStream;
     try {
       encodedTestStream = CoderUtils.encodeToBase64(testStreamCoder, 
testStream);
     } catch (CoderException e) {
-      throw new SamzaException("Could not encode TestStream.", e);
+      throw new RuntimeException("Could not encode TestStream.", e);
     }
 
     // the decoder for encodedTestStream
@@ -70,31 +74,75 @@ public class SamzaTestStreamTranslator<T> implements 
TransformTranslator<TestStr
           try {
             return 
CoderUtils.decodeFromBase64(TestStream.TestStreamCoder.of(valueCoder), string);
           } catch (CoderException e) {
-            throw new SamzaException("Could not decode TestStream.", e);
+            throw new RuntimeException("Could not decode TestStream.", e);
           }
         };
 
+    ctx.registerInputMessageStream(
+        output, createInputDescriptor(outputId, encodedTestStream, 
testStreamDecoder));
+  }
+
+  @Override
+  public void translatePortable(
+      PipelineNode.PTransformNode transform,
+      QueryablePipeline pipeline,
+      PortableTranslationContext ctx) {
+    final ByteString bytes = transform.getTransform().getSpec().getPayload();
+    final SerializableFunction<String, TestStream<T>> testStreamDecoder =
+        createTestStreamDecoder(pipeline.getComponents(), bytes);
+
+    final String outputId = ctx.getOutputId(transform);
+    final String escapedOutputId = 
SamzaPipelineTranslatorUtils.escape(outputId);
+
+    ctx.registerInputMessageStream(
+        outputId,
+        createInputDescriptor(
+            escapedOutputId, Base64Serializer.serializeUnchecked(bytes), 
testStreamDecoder));
+  }
+
+  @SuppressWarnings("unchecked")
+  private static <T> GenericInputDescriptor<KV<?, OpMessage<T>>> 
createInputDescriptor(
+      String id,
+      String encodedTestStream,
+      SerializableFunction<String, TestStream<T>> testStreamDecoder) {
     final Map<String, String> systemConfig =
         ImmutableMap.of(
-            "encodedTestStream",
+            ENCODED_TEST_STREAM,
             encodedTestStream,
-            "testStreamDecoder",
+            TEST_STREAM_DECODER,
             Base64Serializer.serializeUnchecked(testStreamDecoder));
-    systemDescriptor.withSystemConfigs(systemConfig);
+    final GenericSystemDescriptor systemDescriptor =
+        new GenericSystemDescriptor(id, 
SamzaTestStreamSystemFactory.class.getName())
+            .withSystemConfigs(systemConfig);
 
     // The KvCoder is needed here for Samza not to crop the key.
-    final Serde<KV<?, OpMessage<byte[]>>> kvSerde = KVSerde.of(new 
NoOpSerde(), new NoOpSerde<>());
-    final GenericInputDescriptor<KV<?, OpMessage<byte[]>>> inputDescriptor =
-        systemDescriptor.getInputDescriptor(outputId, kvSerde);
-
-    ctx.registerInputMessageStream(output, inputDescriptor);
+    final Serde<KV<?, OpMessage<T>>> kvSerde = KVSerde.of(new NoOpSerde(), new 
NoOpSerde<>());
+    return systemDescriptor.getInputDescriptor(id, kvSerde);
   }
 
-  @Override
-  public void translatePortable(
-      PipelineNode.PTransformNode transform,
-      QueryablePipeline pipeline,
-      PortableTranslationContext ctx) {
-    throw new SamzaException("TestStream is not supported in portable by Samza 
runner");
+  @SuppressWarnings("unchecked")
+  private static <T> SerializableFunction<String, TestStream<T>> 
createTestStreamDecoder(
+      RunnerApi.Components components, ByteString payload) {
+    Coder<T> coder;
+    try {
+      coder =
+          (Coder<T>)
+              RehydratedComponents.forComponents(components)
+                  
.getCoder(RunnerApi.TestStreamPayload.parseFrom(payload).getCoderId());
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    // the decoder for encodedTestStream
+    return encodedTestStream -> {
+      try {
+        return TestStreamTranslation.testStreamFromProtoPayload(
+            RunnerApi.TestStreamPayload.parseFrom(
+                Base64Serializer.deserializeUnchecked(encodedTestStream, 
ByteString.class)),
+            coder);
+      } catch (IOException e) {
+        throw new RuntimeException("Could not decode TestStream.", e);
+      }
+    };
   }
 }

Reply via email to