This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new ca7ea599d feat: `column_name` based index access for `RecordBatch` and
`StructArray` (#3458)
ca7ea599d is described below
commit ca7ea599d963a809c687f6aadc5729c452b11a29
Author: askoa <[email protected]>
AuthorDate: Fri Jan 6 23:50:34 2023 -0500
feat: `column_name` based index access for `RecordBatch` and `StructArray`
(#3458)
* feat: Add `column_name` based index access for `RecordBatch` and
`StructArray`
* change to simpler coding
Co-authored-by: askoa <askoa@local>
---
arrow-array/src/array/struct_array.rs | 39 +++++++++++++++++++++++++++++++++-
arrow-array/src/record_batch.rs | 40 +++++++++++++++++++++++++++++++++++
2 files changed, 78 insertions(+), 1 deletion(-)
diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs
index bf6489c13..dc949c8e4 100644
--- a/arrow-array/src/array/struct_array.rs
+++ b/arrow-array/src/array/struct_array.rs
@@ -20,7 +20,7 @@ use arrow_buffer::buffer::buffer_bin_or;
use arrow_buffer::Buffer;
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field};
-use std::any::Any;
+use std::{any::Any, ops::Index};
/// A nested array type where each child (called *field*) is represented by a separate
/// array.
@@ -296,6 +296,23 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
}
}
+impl Index<&str> for StructArray {
+ type Output = ArrayRef;
+
+ /// Get a reference to a column's array by name.
+ ///
+ /// Note: A schema can currently have duplicate field names, in which case
+ /// the first field will always be selected.
+ /// This issue will be addressed in [ARROW-11178](https://issues.apache.org/jira/browse/ARROW-11178)
+ ///
+ /// # Panics
+ ///
+ /// Panics if the name is not in the schema.
+ fn index(&self, name: &str) -> &Self::Output {
+ self.column_by_name(name).unwrap()
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -352,6 +369,26 @@ mod tests {
assert_eq!(0, struct_array.offset());
}
+ /// validates that struct can be accessed using `column_name` as index i.e. `struct_array["column_name"]`.
+ #[test]
+ fn test_struct_array_index_access() {
+ let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
+ let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
+
+ let struct_array = StructArray::from(vec![
+ (
+ Field::new("b", DataType::Boolean, false),
+ boolean.clone() as ArrayRef,
+ ),
+ (
+ Field::new("c", DataType::Int32, false),
+ int.clone() as ArrayRef,
+ ),
+ ]);
+ assert_eq!(struct_array["b"].as_ref(), boolean.as_ref());
+ assert_eq!(struct_array["c"].as_ref(), int.as_ref());
+ }
+
/// validates that the in-memory representation follows [the spec](https://arrow.apache.org/docs/format/Columnar.html#struct-layout)
#[test]
fn test_struct_array_from_vec() {
diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs
index ea0eb3853..72b567f75 100644
--- a/arrow-array/src/record_batch.rs
+++ b/arrow-array/src/record_batch.rs
@@ -20,6 +20,7 @@
use crate::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
+use std::ops::Index;
use std::sync::Arc;
/// Trait for types that can read `RecordBatch`'s.
@@ -288,6 +289,13 @@ impl RecordBatch {
&self.columns[index]
}
+ /// Get a reference to a column's array by name.
+ pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
+ self.schema()
+ .column_with_name(name)
+ .map(|(index, _)| &self.columns[index])
+ }
+
/// Get a reference to all columns in the record batch.
pub fn columns(&self) -> &[ArrayRef] {
&self.columns[..]
@@ -473,6 +481,19 @@ impl From<RecordBatch> for StructArray {
}
}
+impl Index<&str> for RecordBatch {
+ type Output = ArrayRef;
+
+ /// Get a reference to a column's array by name.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the name is not in the schema.
+ fn index(&self, name: &str) -> &Self::Output {
+ self.column_by_name(name).unwrap()
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -746,6 +767,25 @@ mod tests {
assert_eq!(batch1, batch2);
}
+ /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]`
+ #[test]
+ fn record_batch_index_access() {
+ let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
+ let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
+ let schema1 = Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("val", DataType::Int32, false),
+ ]);
+ let record_batch = RecordBatch::try_new(
+ Arc::new(schema1),
+ vec![id_arr.clone(), val_arr.clone()],
+ )
+ .unwrap();
+
+ assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
+ assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
+ }
+
#[test]
fn record_batch_vals_ne() {
let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);