Kontinuation commented on code in PR #522:
URL: https://github.com/apache/sedona-db/pull/522#discussion_r2698277127


##########
rust/sedona-spatial-join/src/evaluated_batch/spill.rs:
##########
@@ -0,0 +1,794 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{Float64Array, Float64Builder};
+use arrow_array::{Array, RecordBatch, StructArray};
+use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
+use datafusion::config::SpillCompression;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_execution::{disk_manager::RefCountedTempFile, runtime_env::RuntimeEnv};
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_plan::metrics::SpillMetrics;
+use sedona_common::sedona_internal_err;
+use sedona_schema::datatypes::SedonaType;
+
+use crate::{
+    evaluated_batch::EvaluatedBatch,
+    operand_evaluator::EvaluatedGeometryArray,
+    utils::spill::{RecordBatchSpillReader, RecordBatchSpillWriter},
+};
+
+/// Writer for spilling evaluated batches to disk
+pub struct EvaluatedBatchSpillWriter {
+    /// The temporary spill file being written to
+    inner: RecordBatchSpillWriter,
+
+    /// Schema of the spilled record batches. It is augmented from the schema of the
+    /// original record batches. The spill schema has 3 fields:
+    /// * `data`: StructArray containing the original record batch columns
+    /// * `geom`: geometry array in storage format
+    /// * `dist`: nullable Float64 distance
+    spill_schema: Schema,
+    /// Inner fields of the "data" StructArray in the spilled record batches
+    data_inner_fields: Fields,
+}
+
+impl EvaluatedBatchSpillWriter {
+    /// Create a new `EvaluatedBatchSpillWriter`
+    pub fn try_new(
+        env: Arc<RuntimeEnv>,
+        schema: SchemaRef,
+        sedona_type: &SedonaType,
+        request_description: &str,
+        compression: SpillCompression,
+        metrics: SpillMetrics,
+        batch_size_threshold: Option<usize>,
+    ) -> Result<Self> {
+        // Construct the schema of record batches to be written. The written batches
+        // are augmented from the original record batches.
+        let data_inner_fields = schema.fields().clone();
+        let data_struct_field =
+            Field::new("data", DataType::Struct(data_inner_fields.clone()), 
false);
+        let geom_field = sedona_type.to_storage_field("geom", true)?;
+        let dist_field = Field::new("dist", DataType::Float64, true);
+        let spill_schema = Schema::new(vec![data_struct_field, geom_field, dist_field]);
+
+        // Create spill file
+        let inner = RecordBatchSpillWriter::try_new(
+            env,
+            Arc::new(spill_schema.clone()),
+            request_description,
+            compression,
+            metrics,
+            batch_size_threshold,
+        )?;
+
+        Ok(Self {
+            inner,
+            spill_schema,
+            data_inner_fields,
+        })
+    }
+
+    /// Append an EvaluatedBatch to the spill file
+    pub fn append(&mut self, evaluated_batch: &EvaluatedBatch) -> Result<()> {
+        let record_batch = self.spilled_record_batch(evaluated_batch)?;
+
+        // Splitting/compaction and spill bytes/rows metrics are handled by `RecordBatchSpillWriter`.
+        self.inner.write_batch(&record_batch)?;
+        Ok(())
+    }
+
+    /// Finish writing and return the temporary file
+    pub fn finish(self) -> Result<RefCountedTempFile> {
+        self.inner.finish()
+    }
+
+    fn spilled_record_batch(&self, evaluated_batch: &EvaluatedBatch) -> Result<RecordBatch> {
+        let num_rows = evaluated_batch.num_rows();
+
+        // Store the original data batch into a StructArray
+        let data_batch = &evaluated_batch.batch;
+        let data_arrays = data_batch.columns().to_vec();
+        let data_struct_array =
+            StructArray::try_new(self.data_inner_fields.clone(), data_arrays, None)?;
+
+        // Store dist into a Float64Array
+        let mut dist_builder = Float64Builder::with_capacity(num_rows);
+        let geom_array = &evaluated_batch.geom_array;
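+        // The evaluated distance is either a single scalar broadcast to every
+        // row, a per-row Float64 array, or absent entirely.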
+        match &geom_array.distance {
+            Some(ColumnarValue::Scalar(scalar)) => match scalar {
+                ScalarValue::Float64(dist_value) => {
+                    for _ in 0..num_rows {
+                        dist_builder.append_option(*dist_value);
+                    }
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(
+                        "Distance columnar value is not a 
Float64Array".to_string(),
+                    ));
+                }
+            },
+            Some(ColumnarValue::Array(array)) => {
+                let float_array = array
+                    .as_any()
+                    .downcast_ref::<Float64Array>()
+                    .ok_or_else(|| {
+                        DataFusionError::Internal(
+                            "Distance columnar value is not a Float64Array".to_string(),
+                        )
+                    })?;
+                dist_builder.append_array(float_array);
+            }
+            None => {
+                for _ in 0..num_rows {
+                    dist_builder.append_null();
+                }
+            }
+        }
+        let dist_array = dist_builder.finish();
+
+        // Assemble the final spilled RecordBatch
+        let columns = vec![
+            Arc::new(data_struct_array) as Arc<dyn Array>,
+            Arc::clone(&geom_array.geometry_array),
+            Arc::new(dist_array) as Arc<dyn Array>,
+        ];
+        let spilled_record_batch =
+            RecordBatch::try_new(Arc::new(self.spill_schema.clone()), columns)?;
+        Ok(spilled_record_batch)
+    }
+}
+
+/// Reader for reading spilled evaluated batches from disk
+pub struct EvaluatedBatchSpillReader {
+    inner: RecordBatchSpillReader,
+}
+
+impl EvaluatedBatchSpillReader {
+    /// Create a new `EvaluatedBatchSpillReader`
+    pub fn try_new(temp_file: &RefCountedTempFile) -> Result<Self> {
+        Ok(Self {
+            inner: RecordBatchSpillReader::try_new(temp_file)?,
+        })
+    }
+
+    /// Get the schema of the spilled data
+    pub fn schema(&self) -> SchemaRef {
+        self.inner.schema()
+    }
+
+    /// Read the next EvaluatedBatch from the spill file
+    #[allow(unused)]
+    pub fn next_batch(&mut self) -> Option<Result<EvaluatedBatch>> {
+        self.next_raw_batch()
+            .map(|record_batch| record_batch.and_then(spilled_batch_to_evaluated_batch))
+    }
+
+    /// Read the next raw RecordBatch from the spill file
+    pub fn next_raw_batch(&mut self) -> Option<Result<RecordBatch>> {
+        self.inner.next_batch()
+    }
+}
+
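+/// Convert a spilled record batch back into an [`EvaluatedBatch`], reversing
+/// the layout produced by `EvaluatedBatchSpillWriter::spilled_record_batch`.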
+pub(crate) fn spilled_batch_to_evaluated_batch(
+    record_batch: RecordBatch,
+) -> Result<EvaluatedBatch> {
+    // Extract the data struct array (column 0) and convert it back to the original RecordBatch
+    let data_array = record_batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .ok_or_else(|| {
+            DataFusionError::Internal("Expected data column to be a 
StructArray".to_string())
+        })?;
+
+    let data_schema = Arc::new(Schema::new(match data_array.data_type() {
+        DataType::Struct(fields) => fields.clone(),
+        _ => {
+            return Err(DataFusionError::Internal(
+                "Expected data column to have Struct data type".to_string(),
+            ))
+        }
+    }));
+
+    let data_columns = (0..data_array.num_columns())
+        .map(|i| Arc::clone(data_array.column(i)))
+        .collect::<Vec<_>>();
+
+    let batch = RecordBatch::try_new(data_schema, data_columns)?;
+
+    // Extract the geometry array (column 1)
+    let geom_array = Arc::clone(record_batch.column(1));
+
+    // Determine the SedonaType from the geometry field in the record batch schema
+    let schema = record_batch.schema();
+    let geom_field = schema.field(1);
+    let sedona_type = SedonaType::from_storage_field(geom_field)?;

Review Comment:
   Fixed this by defining spill field indexes as constants.
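
   For illustration, a minimal sketch of the "field indexes as constants" approach (the names below are hypothetical; the actual identifiers in the PR may differ):

   ```rust
   // Hypothetical constant names for the spill-schema column indexes; the
   // PR's actual identifiers may differ.
   const SPILL_DATA_COLUMN_INDEX: usize = 0; // `data` StructArray
   const SPILL_GEOM_COLUMN_INDEX: usize = 1; // `geom` geometry array
   const SPILL_DIST_COLUMN_INDEX: usize = 2; // `dist` Float64 distances
   ```

   Indexing spilled batches through these constants (e.g. `record_batch.column(SPILL_DATA_COLUMN_INDEX)`) keeps the writer and reader in sync if the spill schema ever changes.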


