This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 9047d99f6b [arrow-cast] Support cast boolean from/to string view
(#6822)
9047d99f6b is described below
commit 9047d99f6bf87582532ee6ed0acb3f2d5f889f11
Author: Tai Le Manh <[email protected]>
AuthorDate: Tue Dec 3 15:55:55 2024 +0700
[arrow-cast] Support cast boolean from/to string view (#6822)
Signed-off-by: Tai Le Manh <[email protected]>
---
arrow-cast/src/cast/mod.rs | 31 ++++++++++++++++++++++++++++---
arrow-cast/src/cast/string.rs | 37 ++++++++++++++++++++++++++++---------
2 files changed, 56 insertions(+), 12 deletions(-)
diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs
index 5192f4a8ac..0a44392a24 100644
--- a/arrow-cast/src/cast/mod.rs
+++ b/arrow-cast/src/cast/mod.rs
@@ -197,13 +197,18 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Struct(_), _) => false,
(_, Struct(_)) => false,
(_, Boolean) => {
- DataType::is_integer(from_type) ||
- DataType::is_floating(from_type)
+ DataType::is_integer(from_type)
+ || DataType::is_floating(from_type)
+ || from_type == &Utf8View
|| from_type == &Utf8
|| from_type == &LargeUtf8
}
(Boolean, _) => {
- DataType::is_integer(to_type) || DataType::is_floating(to_type) ||
to_type == &Utf8 || to_type == &LargeUtf8
+ DataType::is_integer(to_type)
+ || DataType::is_floating(to_type)
+ || to_type == &Utf8View
+ || to_type == &Utf8
+ || to_type == &LargeUtf8
}
(Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) |
BinaryView | Utf8View ) => true,
@@ -1200,6 +1205,7 @@ pub fn cast_with_options(
Float16 => cast_numeric_to_bool::<Float16Type>(array),
Float32 => cast_numeric_to_bool::<Float32Type>(array),
Float64 => cast_numeric_to_bool::<Float64Type>(array),
+ Utf8View => cast_utf8view_to_boolean(array, cast_options),
Utf8 => cast_utf8_to_boolean::<i32>(array, cast_options),
LargeUtf8 => cast_utf8_to_boolean::<i64>(array, cast_options),
_ => Err(ArrowError::CastError(format!(
@@ -1218,6 +1224,7 @@ pub fn cast_with_options(
Float16 => cast_bool_to_numeric::<Float16Type>(array,
cast_options),
Float32 => cast_bool_to_numeric::<Float32Type>(array,
cast_options),
Float64 => cast_bool_to_numeric::<Float64Type>(array,
cast_options),
+ Utf8View => value_to_string_view(array, cast_options),
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
_ => Err(ArrowError::CastError(format!(
@@ -3840,6 +3847,14 @@ mod tests {
assert_eq!(*as_boolean_array(&casted), expected);
}
+ #[test]
+ fn test_cast_utf8view_to_bool() {
+ let strings = StringViewArray::from(vec!["true", "false", "invalid", " Y ", ""]);
+ let casted = cast(&strings, &DataType::Boolean).unwrap();
+ let expected = BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]);
+ assert_eq!(*as_boolean_array(&casted), expected);
+ }
+
#[test]
fn test_cast_with_options_utf8_to_bool() {
let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]);
@@ -3871,6 +3886,16 @@ mod tests {
assert!(!c.is_valid(2));
}
+ #[test]
+ fn test_cast_bool_to_utf8view() {
+ let array = BooleanArray::from(vec![Some(true), Some(false), None]);
+ let b = cast(&array, &DataType::Utf8View).unwrap();
+ let c = b.as_any().downcast_ref::<StringViewArray>().unwrap();
+ assert_eq!("true", c.value(0));
+ assert_eq!("false", c.value(1));
+ assert!(!c.is_valid(2));
+ }
+
#[test]
fn test_cast_bool_to_utf8() {
let array = BooleanArray::from(vec![Some(true), Some(false), None]);
diff --git a/arrow-cast/src/cast/string.rs b/arrow-cast/src/cast/string.rs
index 07366a785a..7f22c4fd64 100644
--- a/arrow-cast/src/cast/string.rs
+++ b/arrow-cast/src/cast/string.rs
@@ -368,19 +368,14 @@ pub(crate) fn cast_binary_to_string<O: OffsetSizeTrait>(
}
}
-/// Casts Utf8 to Boolean
-pub(crate) fn cast_utf8_to_boolean<OffsetSize>(
- from: &dyn Array,
+/// Casts string to boolean
+fn cast_string_to_boolean<'a, StrArray>(
+ array: &StrArray,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
- OffsetSize: OffsetSizeTrait,
+ StrArray: StringArrayType<'a>,
{
- let array = from
- .as_any()
- .downcast_ref::<GenericStringArray<OffsetSize>>()
- .unwrap();
-
let output_array = array
.iter()
.map(|value| match value {
@@ -402,3 +397,27 @@ where
Ok(Arc::new(output_array))
}
+
+pub(crate) fn cast_utf8_to_boolean<OffsetSize>(
+ from: &dyn Array,
+ cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+ OffsetSize: OffsetSizeTrait,
+{
+ let array = from
+ .as_any()
+ .downcast_ref::<GenericStringArray<OffsetSize>>()
+ .unwrap();
+
+ cast_string_to_boolean(&array, cast_options)
+}
+
+pub(crate) fn cast_utf8view_to_boolean(
+ from: &dyn Array,
+ cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError> {
+ let array = from.as_any().downcast_ref::<StringViewArray>().unwrap();
+
+ cast_string_to_boolean(&array, cast_options)
+}