This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 92c94b123c11eab6e4cfd2441a64463253f2afa2
Author: Etienne Chauchot <[email protected]>
AuthorDate: Wed Jan 2 16:08:31 2019 +0100

    Refactor DatasetSource fields
---
 .../translation/batch/DatasetSourceBatch.java      | 40 ++++++++++++----------
 1 file changed, 22 insertions(+), 18 deletions(-)

diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
index 331e397..e19bbdb 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
@@ -49,10 +49,6 @@ public class DatasetSourceBatch<T> implements DataSourceV2, ReadSupport {
   static final String BEAM_SOURCE_OPTION = "beam-source";
   static final String DEFAULT_PARALLELISM = "default-parallelism";
   static final String PIPELINE_OPTIONS = "pipeline-options";
-  private int numPartitions;
-  private Long bundleSize;
-  private BoundedSource<T> source;
-  private SparkPipelineOptions sparkPipelineOptions;
 
 
   @SuppressWarnings("unchecked")
@@ -61,31 +57,39 @@ public class DatasetSourceBatch<T> implements DataSourceV2, ReadSupport {
     if (!options.get(BEAM_SOURCE_OPTION).isPresent()){
       throw new RuntimeException("Beam source was not set in DataSource 
options");
     }
-    this.source = Base64Serializer
+    BoundedSource<T> source = Base64Serializer
         .deserializeUnchecked(options.get(BEAM_SOURCE_OPTION).get(), 
BoundedSource.class);
 
     if (!options.get(DEFAULT_PARALLELISM).isPresent()){
       throw new RuntimeException("Spark default parallelism was not set in 
DataSource options");
     }
-    if (!options.get(BEAM_SOURCE_OPTION).isPresent()){
-      throw new RuntimeException("Beam source was not set in DataSource 
options");
-    }
-    this.numPartitions = 
Integer.valueOf(options.get(DEFAULT_PARALLELISM).get());
-    checkArgument(this.numPartitions > 0, "Number of partitions must be 
greater than zero.");
+    int numPartitions = 
Integer.valueOf(options.get(DEFAULT_PARALLELISM).get());
+    checkArgument(numPartitions > 0, "Number of partitions must be greater 
than zero.");
+
     if (!options.get(PIPELINE_OPTIONS).isPresent()){
       throw new RuntimeException("Beam pipelineOptions were not set in 
DataSource options");
     }
-    this.sparkPipelineOptions = SerializablePipelineOptions
+    SparkPipelineOptions sparkPipelineOptions = SerializablePipelineOptions
         
.deserializeFromJson(options.get(PIPELINE_OPTIONS).get()).as(SparkPipelineOptions.class);
-    this.bundleSize = sparkPipelineOptions.getBundleSize();
-    return new DatasetReader();  }
+    return new DatasetReader(numPartitions, source, sparkPipelineOptions);
+  }
 
   /** This class can be mapped to Beam {@link BoundedSource}. */
   private class DatasetReader implements DataSourceReader {
 
+    private int numPartitions;
+    private BoundedSource<T> source;
+    private SparkPipelineOptions sparkPipelineOptions;
     private Optional<StructType> schema;
     private String checkpointLocation;
 
+    private DatasetReader(int numPartitions, BoundedSource<T> source,
+        SparkPipelineOptions sparkPipelineOptions) {
+      this.numPartitions = numPartitions;
+      this.source = source;
+      this.sparkPipelineOptions = sparkPipelineOptions;
+    }
+
     @Override
     public StructType readSchema() {
       return new StructType();
@@ -97,11 +101,11 @@ public class DatasetSourceBatch<T> implements DataSourceV2, ReadSupport {
       long desiredSizeBytes;
       try {
         desiredSizeBytes =
-            (bundleSize == null)
+            (sparkPipelineOptions.getBundleSize() == null)
                 ? source.getEstimatedSizeBytes(sparkPipelineOptions) / 
numPartitions
-                : bundleSize;
-        List<? extends BoundedSource<T>> sources = 
source.split(desiredSizeBytes, sparkPipelineOptions);
-        for (BoundedSource<T> source : sources) {
+                : sparkPipelineOptions.getBundleSize();
+        List<? extends BoundedSource<T>> splits = 
source.split(desiredSizeBytes, sparkPipelineOptions);
+        for (BoundedSource<T> split : splits) {
           result.add(
               new InputPartition<InternalRow>() {
 
@@ -109,7 +113,7 @@ public class DatasetSourceBatch<T> implements DataSourceV2, ReadSupport {
                 public InputPartitionReader<InternalRow> 
createPartitionReader() {
                   BoundedReader<T> reader = null;
                   try {
-                    reader = source.createReader(sparkPipelineOptions);
+                    reader = split.createReader(sparkPipelineOptions);
                   } catch (IOException e) {
                     throw new RuntimeException(
                         "Error creating BoundedReader " + 
reader.getClass().getCanonicalName(), e);

Reply via email to