tustvold commented on code in PR #6049:
URL: https://github.com/apache/arrow-datafusion/pull/6049#discussion_r1179709793
##########
datafusion/core/src/physical_plan/memory.rs:
##########
@@ -223,15 +245,365 @@ impl RecordBatchStream for MemoryStream {
}
}
+/// Execution plan for writing record batches to an in-memory table.
+///
+/// Each input partition is drained into one of the table's partitions; the
+/// mapping from input partitions to table partitions is chosen in `execute`.
+pub struct MemoryWriteExec {
+    /// Input plan that produces the record batches to be written.
+    input: Arc<dyn ExecutionPlan>,
+    /// Shared handles to the MemTable's partitions, one entry per table
+    /// partition.
+    batches: Vec<PartitionData>,
+    /// Schema describing the structure of the data.
+    schema: SchemaRef,
+}
+
+impl fmt::Debug for MemoryWriteExec {
+    /// Renders only the plan's schema, e.g. `schema: Schema { .. }`.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_fmt(format_args!("schema: {:?}", self.schema))
+    }
+}
+
+impl ExecutionPlan for MemoryWriteExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Get the schema for this execution plan
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    /// One output partition per input partition; the partitioning scheme
+    /// itself is not tracked.
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(
+            self.input.output_partitioning().partition_count(),
+        )
+    }
+
+    fn benefits_from_input_partitioning(&self) -> bool {
+        false
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        self.input.output_ordering()
+    }
+
+    fn required_input_distribution(&self) -> Vec<Distribution> {
+        // If the MemTable has a single partition, require SinglePartition:
+        // every input batch is written to that one partition anyway, and
+        // stating it lets the plan optimizer produce better plans.
+        if self.batches.len() == 1 {
+            vec![Distribution::SinglePartition]
+        } else {
+            vec![Distribution::UnspecifiedDistribution]
+        }
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        // In theory, if the MemTable partition count equals the input plan's
+        // output partition count, each input partition is written to its own
+        // table partition, so the order inside the partitions can be preserved.
+        vec![self.batches.len() == self.input.output_partitioning().partition_count()]
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.input.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(MemoryWriteExec::try_new(
+            children[0].clone(),
+            self.batches.clone(),
+            self.schema.clone(),
+        )?))
+    }
+
+    /// Execute the plan and return a stream for the specified partition.
+    ///
+    /// If the MemTable has at least as many partitions as the input, input
+    /// partition `i` is written to table partition `i` by a stream that
+    /// locks the table partition only once. Otherwise several input
+    /// partitions share a table partition (`partition % batch_count`), and a
+    /// stream that takes the write lock for every batch is used instead.
+    ///
+    /// Returns an error if the MemTable has no partitions to write into.
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let batch_count = self.batches.len();
+        // Without this guard, an empty MemTable would make the modulo below
+        // panic with a division by zero.
+        if batch_count == 0 {
+            return Err(datafusion_common::DataFusionError::Internal(
+                "MemoryWriteExec requires a MemTable with at least one partition"
+                    .to_string(),
+            ));
+        }
+        let data = self.input.execute(partition, context)?;
+        if batch_count >= self.input.output_partitioning().partition_count() {
+            // Every input partition has its own table partition, so the
+            // lightweight implementation that locks only once can be used.
+            let table_partition = self.batches[partition].clone();
+            Ok(Box::pin(MemorySinkOneToOneStream::try_new(
+                table_partition,
+                data,
+                self.schema.clone(),
+            )?))
+        } else {
+            // Several input partitions map onto the same table partition, so
+            // use the implementation that locks before each write.
+            let table_partition = self.batches[partition % batch_count].clone();
+            Ok(Box::pin(MemorySinkStream::try_new(
+                table_partition,
+                data,
+                self.schema.clone(),
+            )?))
+        }
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(
+                    f,
+                    "MemoryWriteExec: partitions={}, input_partition={}",
+                    self.batches.len(),
+                    self.input.output_partitioning().partition_count()
+                )
+            }
+        }
+    }
+
+    fn statistics(&self) -> Statistics {
+        Statistics::default()
+    }
+}
+
+impl MemoryWriteExec {
+    /// Create a new execution plan that writes the record batches produced by
+    /// `plan` into the given MemTable partitions.
+    ///
+    /// `batches` holds one shared handle per MemTable partition, and `schema`
+    /// is the schema of the data being written.
+    ///
+    /// This constructor currently always returns `Ok`.
+    pub fn try_new(
+        plan: Arc<dyn ExecutionPlan>,
+        batches: Vec<PartitionData>,
+        schema: SchemaRef,
+    ) -> Result<Self> {
+        Ok(Self {
+            input: plan,
+            batches,
+            schema,
+        })
+    }
+}
+
+/// This object encodes the different states of the [`MemorySinkStream`] when
+/// processing record batches.
+enum MemorySinkStreamState {
+    /// The stream is pulling data from the input.
+    Pull,
+    /// The stream is waiting for the write lock on the table partition so it
+    /// can store `maybe_batch`.
+    Write {
+        /// The batch to insert once the lock has been acquired.
+        maybe_batch: Option<RecordBatch>,
+        /// Pending write-lock acquisition. It is kept across polls so the task
+        /// stays registered in the lock's wait queue; recreating (and dropping)
+        /// the future on every poll would discard the registered waker.
+        lock: futures::future::BoxFuture<
+            'static,
+            OwnedRwLockWriteGuard<Vec<RecordBatch>>,
+        >,
+    },
+}
+
+/// A stream that saves record batches in memory-backed storage.
+/// It works even when multiple input partitions map to the same table
+/// partition, because it takes the partition's write lock before each write.
+pub struct MemorySinkStream {
+    /// Stream of record batches to be inserted into the memory table.
+    data: SendableRecordBatchStream,
+    /// Memory table partition that stores the record batches.
+    table_partition: PartitionData,
+    /// Schema representing the structure of the data.
+    schema: SchemaRef,
+    /// State of the iterator when processing multiple polls.
+    state: MemorySinkStreamState,
+}
+
+impl MemorySinkStream {
+    /// Create a new `MemorySinkStream` with the provided parameters.
+    pub fn try_new(
+        table_partition: PartitionData,
+        data: SendableRecordBatchStream,
+        schema: SchemaRef,
+    ) -> Result<Self> {
+        Ok(Self {
+            table_partition,
+            data,
+            schema,
+            state: MemorySinkStreamState::Pull,
+        })
+    }
+
+    /// Implementation of the `poll_next` method. It keeps polling the input
+    /// stream, switching between the Pull and Write states, and returns
+    /// `None` once the input is exhausted. An input error is returned
+    /// immediately. No record batches are ever yielded to the caller.
+    fn poll_next_impl(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            match &mut self.state {
+                MemorySinkStreamState::Pull => {
+                    // Pull data from the input stream.
+                    match ready!(self.data.as_mut().poll_next(cx)) {
+                        Some(Ok(batch)) => {
+                            // Start acquiring the write lock. The pending
+                            // future lives in the state until it resolves.
+                            let lock =
+                                self.table_partition.clone().write_owned().boxed();
+                            self.state = MemorySinkStreamState::Write {
+                                maybe_batch: Some(batch),
+                                lock,
+                            };
+                        }
+                        // Return the error immediately.
+                        Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                        // The input stream is exhausted.
+                        None => return Poll::Ready(None),
+                    }
+                }
+                MemorySinkStreamState::Write { maybe_batch, lock } => {
+                    // Poll the stored lock future. On `Pending` it stays in
+                    // the state, so the waker it registered remains valid.
+                    let mut partition = ready!(lock.poll_unpin(cx));
+                    if let Some(b) = mem::take(maybe_batch) {
+                        // Insert the batch into the table partition.
+                        partition.push(b);
+                    }
+                    // Switch back to the Pull state; the guard is released at
+                    // the end of this arm.
+                    self.state = MemorySinkStreamState::Pull;
+                }
+            }
+        }
+    }
+}
+
+impl Stream for MemorySinkStream {
+    type Item = Result<RecordBatch>;
+
+    /// Forwards to `MemorySinkStream::poll_next_impl`. The stream is `Unpin`,
+    /// so the pinned reference can be unwrapped into a plain `&mut Self`.
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let this = self.get_mut();
+        this.poll_next_impl(cx)
+    }
+}
+
+impl RecordBatchStream for MemorySinkStream {
+    /// Returns the schema this stream was constructed with.
+    fn schema(&self) -> SchemaRef {
+        SchemaRef::clone(&self.schema)
+    }
+}
+
+/// This object encodes the different states of the [`MemorySinkOneToOneStream`]
+/// when processing record batches.
+enum MemorySinkOneToOneStreamState {
+    /// The [`MemorySinkOneToOneStream`] is waiting to acquire the write lock
+    /// on the shared partition in which the record batches are stored.
+    Acquire,
+
+    /// The [`MemorySinkOneToOneStream`] holds the write lock on the shared
+    /// partition and can pull record batches from the input stream to store
+    /// in the partition.
+    Pull {
+        /// An [`OwnedRwLockWriteGuard`] wrapping the shared partition, which
+        /// gives exclusive write access to the underlying `Vec<RecordBatch>`.
+        partition: OwnedRwLockWriteGuard<Vec<RecordBatch>>,
+    },
+}
+
+/// A stream that saves record batches in memory-backed storage.
+/// Assumes that every table partition has at most one corresponding input
+/// partition, so it locks the table partition only once.
+pub struct MemorySinkOneToOneStream {
Review Comment:
Same here: I would keep these private so they can be tweaked later, e.g. to
just use `futures::stream::unfold`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]