Kontinuation commented on code in PR #522:
URL: https://github.com/apache/sedona-db/pull/522#discussion_r2697293848


##########
rust/sedona-spatial-join/src/utils/spill.rs:
##########
@@ -0,0 +1,314 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{fs::File, io::BufReader, sync::Arc};
+
+use arrow::ipc::{
+    reader::StreamReader,
+    writer::{IpcWriteOptions, StreamWriter},
+};
+use arrow_array::RecordBatch;
+use arrow_schema::SchemaRef;
+use datafusion::config::SpillCompression;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_execution::{disk_manager::RefCountedTempFile, 
runtime_env::RuntimeEnv};
+use datafusion_physical_plan::metrics::SpillMetrics;
+
+use crate::utils::arrow_utils::{compact_batch, get_record_batch_memory_size};
+
/// Generic Arrow IPC stream spill writer for [`RecordBatch`].
///
/// Shared between multiple components so spill metrics are updated consistently.
///
/// Batches are encoded with an Arrow IPC [`StreamWriter`] into a temporary file obtained from
/// the runtime's disk manager; `finish` hands that file back as a [`RefCountedTempFile`].
pub(crate) struct RecordBatchSpillWriter {
    /// Temporary file the IPC stream is being written to.
    // NOTE(review): fields drop in declaration order, so on an early drop this handle is
    // released before `writer` closes its `File`. If dropping the last `RefCountedTempFile`
    // deletes the file, confirm that is acceptable on platforms that cannot delete open files.
    in_progress_file: RefCountedTempFile,
    /// IPC stream writer over a `File` opened at `in_progress_file.path()`.
    writer: StreamWriter<File>,
    /// Spill metrics (file count, spilled rows, spilled bytes) updated by this writer.
    metrics: SpillMetrics,
    /// In-memory batch size above which `write_batch` splits incoming batches into slices;
    /// `None` or `Some(0)` disables splitting.
    batch_size_threshold: Option<usize>,
}
+
+impl RecordBatchSpillWriter {
+    pub fn try_new(
+        env: Arc<RuntimeEnv>,
+        schema: SchemaRef,
+        request_description: &str,
+        compression: SpillCompression,
+        metrics: SpillMetrics,
+        batch_size_threshold: Option<usize>,
+    ) -> Result<Self> {
+        let in_progress_file = 
env.disk_manager.create_tmp_file(request_description)?;
+        let file = File::create(in_progress_file.path())?;
+
+        let mut write_options = IpcWriteOptions::default();
+        write_options = 
write_options.try_with_compression(compression.into())?;
+
+        let writer = StreamWriter::try_new_with_options(file, schema.as_ref(), 
write_options)?;
+        metrics.spill_file_count.add(1);
+
+        Ok(Self {
+            in_progress_file,
+            writer,
+            metrics,
+            batch_size_threshold,
+        })
+    }
+
+    /// Write a record batch to the spill file.
+    ///
+    /// If `batch_size_threshold` is configured and the in-memory size of the 
batch exceeds the
+    /// threshold, this will automatically split the batch into smaller slices 
and (optionally)
+    /// compact each slice before writing.
+    pub fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
+        let num_rows = batch.num_rows();
+        if num_rows == 0 {
+            // Preserve "empty batch" semantics: callers may rely on spilling 
and reading back a
+            // zero-row batch (e.g. as a sentinel for an empty stream).
+            return self.write_one_batch(batch);
+        }
+
+        let rows_per_split = self.calculate_rows_per_split(batch, num_rows)?;
+        if rows_per_split < num_rows {
+            let mut offset = 0;
+            while offset < num_rows {
+                let length = std::cmp::min(rows_per_split, num_rows - offset);
+                let slice = batch.slice(offset, length);
+                let compacted = compact_batch(slice)?;
+                self.write_one_batch(&compacted)?;
+                offset += length;
+            }
+        } else {
+            self.write_one_batch(batch)?;
+        }
+        Ok(())
+    }
+
+    fn calculate_rows_per_split(&self, batch: &RecordBatch, num_rows: usize) 
-> Result<usize> {
+        let Some(threshold) = self.batch_size_threshold else {
+            return Ok(num_rows);
+        };
+        if threshold == 0 {
+            return Ok(num_rows);
+        }
+
+        let batch_size = get_record_batch_memory_size(batch)?;
+        if batch_size <= threshold {
+            return Ok(num_rows);
+        }
+
+        let num_splits = batch_size.div_ceil(threshold);
+        let rows = num_rows.div_ceil(num_splits);
+        Ok(std::cmp::max(1, rows))
+    }
+
+    fn write_one_batch(&mut self, batch: &RecordBatch) -> Result<()> {
+        self.writer.write(batch).map_err(|e| {
+            DataFusionError::Execution(format!(
+                "Failed to write RecordBatch to spill file {:?}: {}",
+                self.in_progress_file.path(),
+                e
+            ))
+        })?;
+
+        self.metrics.spilled_rows.add(batch.num_rows());
+        self.metrics
+            .spilled_bytes
+            .add(get_record_batch_memory_size(batch)?);
+
+        Ok(())
+    }
+
+    pub fn finish(mut self) -> Result<RefCountedTempFile> {
+        self.writer.finish()?;
+
+        let mut in_progress_file = self.in_progress_file;
+        in_progress_file.update_disk_usage()?;
+        let size = in_progress_file.current_disk_usage();
+        self.metrics.spilled_bytes.add(size as usize);

Review Comment:
   Removed the per-batch `spilled_bytes.add` and left this one as is, so spilled bytes are counted only once, from the on-disk size of the finished file.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to