mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025017277


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here

Review Comment:
   Yes, the PCollectionView translation just stored the same Spark dataset (reference!) again for a different PTransform. That's obviously problematic for caching, as we're not gathering metadata on that dataset in a single place. Also, the Beam runner guidelines discourage translating PCollectionView; it's just there for legacy reasons.
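   
   For context, here is a minimal, purely hypothetical sketch of what handling a view on the consumer side amounts to; the helper `broadcastSideInput` below is made up for illustration and is not code from this PR:
   
   ```java
   import java.util.List;
   import org.apache.spark.api.java.JavaSparkContext;
   import org.apache.spark.broadcast.Broadcast;
   import org.apache.spark.sql.Dataset;
   import org.apache.spark.sql.SparkSession;
   
   class SideInputSketch {
     // A view is not translated into a dataset of its own: registering the same
     // dataset reference under a second PTransform would skew the dependency
     // counts used to decide what to cache. Instead, a consumer that needs the
     // view evaluates the backing dataset early and broadcasts the collected
     // values (the "early evaluation" mentioned in the translate() javadoc).
     static <T> Broadcast<List<T>> broadcastSideInput(SparkSession session, Dataset<T> backing) {
       List<T> values = backing.collectAsList(); // forces evaluation of this part of the pipeline
       return JavaSparkContext.fromSparkContext(session.sparkContext()).broadcast(values);
     }
   }
   ```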
   
   > ignore, nothing to be translated here, views are handled on the consumer side
   
   Is this sufficient as a comment, or what would you suggest?


