mosche commented on code in PR #22446:
URL: https://github.com/apache/beam/pull/22446#discussion_r943220261


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java:
##########
@@ -17,27 +17,132 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
-import java.util.concurrent.TimeoutException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.spark.sql.streaming.DataStreamWriter;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.spark.api.java.function.ForeachFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
- * Subclass of {@link
- * org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext} that
- * address spark breaking changes.
+ * Base class that gives a context for {@link PTransform} translation: keeping track of the
+ * datasets, the {@link SparkSession}, the current transform being translated.
  */
-public class TranslationContext extends AbstractTranslationContext {
+@SuppressWarnings({
+  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
+  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
+})
+public class TranslationContext {
+
+  private static final Logger LOG = LoggerFactory.getLogger(TranslationContext.class);
+
+  /** All the datasets of the DAG. */
+  private final Map<PValue, Dataset<?>> datasets;
+  /** datasets that are not used as input to other datasets (leaves of the DAG). */
+  private final Set<Dataset<?>> leaves;
+
+  private final SerializablePipelineOptions serializablePipelineOptions;
+
+  private final SparkSession sparkSession;
+
+  private final Map<PCollectionView<?>, Dataset<?>> broadcastDataSets;
+
+  private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
 
   public TranslationContext(SparkStructuredStreamingPipelineOptions options) {
-    super(options);
+    this.sparkSession = SparkSessionFactory.getOrCreateSession(options);
+    this.serializablePipelineOptions = new SerializablePipelineOptions(options);
+    this.datasets = new HashMap<>();
+    this.leaves = new HashSet<>();
+    this.broadcastDataSets = new HashMap<>();
+    this.encoders = new HashMap<>();
+  }
+
+  public SparkSession getSparkSession() {
+    return sparkSession;
+  }
+
+  public SerializablePipelineOptions getSerializableOptions() {
+    return serializablePipelineOptions;
+  }
+
+  // --------------------------------------------------------------------------------------------

Review Comment:
   I kept those from the original code, but I agree.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to