Github user jose-torres commented on a diff in the pull request:
https://github.com/apache/spark/pull/20552#discussion_r167126838
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala ---
@@ -17,52 +17,119 @@
 package org.apache.spark.sql.execution.streaming
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
 import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
-/**
- * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
- * [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @tparam T The expected type of the sink.
- */
-class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
-
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    // This logic should've been as simple as:
-    // ```
-    //   data.as[T].foreachPartition { iter => ... }
-    // ```
-    //
-    // Unfortunately, doing that would just break the incremental planing. The reason is,
-    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
-    // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
-    // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
-    // updated in the new plan, and StreamExecution cannot retrieval them.
-    //
-    // Hence, we need to manually convert internal rows to objects using encoder.
+
+case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
+  override def createStreamWriter(
+      queryId: String,
+      schema: StructType,
+      mode: OutputMode,
+      options: DataSourceOptions): StreamWriter = {
     val encoder = encoderFor[T].resolveAndBind(
-      data.logicalPlan.output,
-      data.sparkSession.sessionState.analyzer)
-    data.queryExecution.toRdd.foreachPartition { iter =>
-      if (writer.open(TaskContext.getPartitionId(), batchId)) {
-        try {
-          while (iter.hasNext) {
-            writer.process(encoder.fromRow(iter.next()))
-          }
-        } catch {
-          case e: Throwable =>
-            writer.close(e)
-            throw e
-        }
-        writer.close(null)
-      } else {
-        writer.close(null)
+      schema.toAttributes,
+      SparkSession.getActiveSession.get.sessionState.analyzer)
+    ForeachInternalWriter(writer, encoder)
+  }
+}
+
+case class ForeachInternalWriter[T: Encoder](
+    writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
+  extends StreamWriter with SupportsWriteInternalRow {
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
+    ForeachWriterFactory(writer, encoder)
+  }
+}
+
+case class ForeachWriterFactory[T: Encoder](writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
+  extends DataWriterFactory[InternalRow] {
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): ForeachDataWriter[T] = {
+    new ForeachDataWriter(writer, encoder, partitionId)
+  }
+}
+
+class ForeachDataWriter[T : Encoder](
+    private var writer: ForeachWriter[T], encoder: ExpressionEncoder[T], partitionId: Int)
+  extends DataWriter[InternalRow] {
+  private val initialEpochId: Long = {
+    // Start with the microbatch ID. If it's not there, we're in continuous execution,
+    // so get the start epoch.
+    // This ID will be incremented as commits happen.
+    TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY) match {
+      case null => TaskContext.get().getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+      case batch => batch.toLong
+    }
+  }
+  private var currentEpochId = initialEpochId
+
+  // The lifecycle of the ForeachWriter is incompatible with the lifecycle of DataSourceV2 writers.
+  // Unfortunately, we cannot migrate ForeachWriter, as its implementations live in user code. So
+  // we need a small state machine to shim between them.
+  //  * CLOSED means close() has been called.
+  //  * OPENED
+  private object WriterState extends Enumeration {
+    type WriterState = Value
+    val CLOSED, OPENED, OPENED_SKIP_PROCESSING = Value
+  }
+  import WriterState._
+
+  private var state = CLOSED
+
+  private def openAndSetState(epochId: Long) = {
+    // Create a new writer by roundtripping through the serialization for compatibility.
+    // In the old API, a writer instantiation would never get reused.
+    val byteStream = new ByteArrayOutputStream()
--- End diff ---
You're right; this suggestion is what we really want.
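For context, the "roundtripping through the serialization" that the truncated hunk begins is the usual Java-serialization copy: write the user's ForeachWriter to a byte array and read it back so each open works on a fresh, unused instance. A minimal sketch of that idea, assuming only the java.io classes already imported in the diff (the helper name `copyViaSerialization` is illustrative and not the PR's code):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Sketch only: copy an object via Java serialization so the original instance
// is never reused. ForeachWriter extends Serializable, so it can be copied this way.
def copyViaSerialization[T <: Serializable](obj: T): T = {
  val byteStream = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(byteStream)
  try out.writeObject(obj) finally out.close()   // serialize the original instance

  val in = new ObjectInputStream(new ByteArrayInputStream(byteStream.toByteArray))
  try in.readObject().asInstanceOf[T] finally in.close()   // deserialize a fresh copy
}
```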
---