Github user jose-torres commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20552#discussion_r167126838

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala ---
@@ -17,52 +17,119 @@ package org.apache.spark.sql.execution.streaming
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
 import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
 
-/**
- * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
- * [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @tparam T The expected type of the sink.
- */
-class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
-
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    // This logic should've been as simple as:
-    // ```
-    // data.as[T].foreachPartition { iter => ... }
-    // ```
-    //
-    // Unfortunately, doing that would just break the incremental planing. The reason is,
-    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
-    // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
-    // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
-    // updated in the new plan, and StreamExecution cannot retrieval them.
-    //
-    // Hence, we need to manually convert internal rows to objects using encoder.
+
+case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
+  override def createStreamWriter(
+      queryId: String,
+      schema: StructType,
+      mode: OutputMode,
+      options: DataSourceOptions): StreamWriter = {
     val encoder = encoderFor[T].resolveAndBind(
-      data.logicalPlan.output,
-      data.sparkSession.sessionState.analyzer)
-    data.queryExecution.toRdd.foreachPartition { iter =>
-      if (writer.open(TaskContext.getPartitionId(), batchId)) {
-        try {
-          while (iter.hasNext) {
-            writer.process(encoder.fromRow(iter.next()))
-          }
-        } catch {
-          case e: Throwable =>
-            writer.close(e)
-            throw e
-        }
-        writer.close(null)
-      } else {
-        writer.close(null)
+      schema.toAttributes,
+      SparkSession.getActiveSession.get.sessionState.analyzer)
+    ForeachInternalWriter(writer, encoder)
+  }
+}
+
+case class ForeachInternalWriter[T: Encoder](
+    writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
+  extends StreamWriter with SupportsWriteInternalRow {
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
+    ForeachWriterFactory(writer, encoder)
+  }
+}
+
+case class ForeachWriterFactory[T: Encoder](writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
+  extends DataWriterFactory[InternalRow] {
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): ForeachDataWriter[T] = {
+    new ForeachDataWriter(writer, encoder, partitionId)
+  }
+}
+
+class ForeachDataWriter[T : Encoder](
+    private var writer: ForeachWriter[T], encoder: ExpressionEncoder[T], partitionId: Int)
+  extends DataWriter[InternalRow] {
+  private val initialEpochId: Long = {
+    // Start with the microbatch ID. If it's not there, we're in continuous execution,
+    // so get the start epoch.
+    // This ID will be incremented as commits happen.
+    TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY) match {
+      case null => TaskContext.get().getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+      case batch => batch.toLong
+    }
+  }
+  private var currentEpochId = initialEpochId
+
+  // The lifecycle of the ForeachWriter is incompatible with the lifecycle of DataSourceV2 writers.
+  // Unfortunately, we cannot migrate ForeachWriter, as its implementations live in user code. So
+  // we need a small state machine to shim between them.
+  // * CLOSED means close() has been called.
+  // * OPENED
+  private object WriterState extends Enumeration {
+    type WriterState = Value
+    val CLOSED, OPENED, OPENED_SKIP_PROCESSING = Value
+  }
+  import WriterState._
+
+  private var state = CLOSED
+
+  private def openAndSetState(epochId: Long) = {
+    // Create a new writer by roundtripping through the serialization for compatibility.
+    // In the old API, a writer instantiation would never get reused.
+    val byteStream = new ByteArrayOutputStream()
--- End diff --

You're right; this suggestion is what we really want.
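For reference, the roundtrip that the truncated `openAndSetState` body is reaching for would look roughly like the sketch below: plain Java serialization of the user's ForeachWriter, so that each epoch drives a fresh instance, matching the old sink's guarantee that a writer instantiation is never reused. The object and method names here are only illustrative, not what the PR actually uses:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    import org.apache.spark.sql.ForeachWriter

    // Hypothetical helper, for illustration only; not part of this PR.
    object ForeachWriterCloneSketch {
      // Serialize and immediately deserialize the user's writer so each epoch
      // processes rows with a fresh ForeachWriter instance.
      def cloneBySerialization[T](original: ForeachWriter[T]): ForeachWriter[T] = {
        val byteStream = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(byteStream)
        try out.writeObject(original) finally out.close()

        val in = new ObjectInputStream(new ByteArrayInputStream(byteStream.toByteArray))
        try in.readObject().asInstanceOf[ForeachWriter[T]] finally in.close()
      }
    }

Whatever final shape the code takes, the important property is that the deserialized copy, not the original writer held by the factory, is the one whose open()/process()/close() lifecycle gets driven for the epoch.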