This is an automated email from the ASF dual-hosted git repository.
jiayuliu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new f6908bf add support for f16 (#888)
f6908bf is described below
commit f6908bfed7f7a17187488c5c3995fb98a9961ac2
Author: Jiayu Liu <[email protected]>
AuthorDate: Mon Nov 29 21:08:13 2021 +0800
add support for f16 (#888)
---
arrow/Cargo.toml | 1 +
arrow/src/alloc/types.rs | 2 ++
arrow/src/array/array.rs | 4 ++--
arrow/src/array/data.rs | 17 ++++++++++-------
arrow/src/array/equal/mod.rs | 6 ++++--
arrow/src/array/mod.rs | 8 ++++++++
arrow/src/array/transform/mod.rs | 18 +++++++++---------
arrow/src/datatypes/native.rs | 11 +++++++++--
arrow/src/datatypes/types.rs | 2 ++
9 files changed, 47 insertions(+), 22 deletions(-)
diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index 694d31d..51e74b0 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -43,6 +43,7 @@ serde_json = { version = "1.0", features = ["preserve_order"]
}
indexmap = "1.6"
rand = { version = "0.8", optional = true }
num = "0.4"
+half = "1.8"
csv_crate = { version = "1.1", optional = true, package="csv" }
regex = "1.3"
lazy_static = "1.4"
diff --git a/arrow/src/alloc/types.rs b/arrow/src/alloc/types.rs
index 92a6107..026e124 100644
--- a/arrow/src/alloc/types.rs
+++ b/arrow/src/alloc/types.rs
@@ -16,6 +16,7 @@
// under the License.
use crate::datatypes::DataType;
+use half::f16;
/// A type that Rust's custom allocator knows how to allocate and deallocate.
/// This is implemented for all Arrow's physical types whose in-memory
representation
@@ -67,5 +68,6 @@ create_native!(
i64,
DataType::Int64 | DataType::Date64 | DataType::Time64(_) |
DataType::Timestamp(_, _)
);
+create_native!(f16, DataType::Float16);
create_native!(f32, DataType::Float32);
create_native!(f64, DataType::Float64);
diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs
index fcf4647..34cdb73 100644
--- a/arrow/src/array/array.rs
+++ b/arrow/src/array/array.rs
@@ -240,7 +240,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef,
DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef,
DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef,
- DataType::Float16 => panic!("Float16 datatype not supported"),
+ DataType::Float16 => Arc::new(Float16Array::from(data)) as ArrayRef,
DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef,
DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef,
DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef,
@@ -393,7 +393,7 @@ pub fn new_null_array(data_type: &DataType, length: usize)
-> ArrayRef {
DataType::UInt8 => new_null_sized_array::<UInt8Type>(data_type,
length),
DataType::Int16 => new_null_sized_array::<Int16Type>(data_type,
length),
DataType::UInt16 => new_null_sized_array::<UInt16Type>(data_type,
length),
- DataType::Float16 => unreachable!(),
+ DataType::Float16 => new_null_sized_array::<Float16Type>(data_type,
length),
DataType::Int32 => new_null_sized_array::<Int32Type>(data_type,
length),
DataType::UInt32 => new_null_sized_array::<UInt32Type>(data_type,
length),
DataType::Float32 => new_null_sized_array::<Float32Type>(data_type,
length),
diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs
index 40a8bee..da609b8 100644
--- a/arrow/src/array/data.rs
+++ b/arrow/src/array/data.rs
@@ -18,10 +18,6 @@
//! Contains `ArrayData`, a generic representation of Arrow array data which
encapsulates
//! common attributes and operations for Arrow array.
-use std::convert::TryInto;
-use std::mem;
-use std::sync::Arc;
-
use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::{bitmap::Bitmap, datatypes::ArrowNativeType};
@@ -29,6 +25,10 @@ use crate::{
buffer::{Buffer, MutableBuffer},
util::bit_util,
};
+use half::f16;
+use std::convert::TryInto;
+use std::mem;
+use std::sync::Arc;
use super::equal::equal;
@@ -89,6 +89,10 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity:
usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<i64>()),
empty_buffer,
],
+ DataType::Float16 => [
+ MutableBuffer::new(capacity * mem::size_of::<f16>()),
+ empty_buffer,
+ ],
DataType::Float32 => [
MutableBuffer::new(capacity * mem::size_of::<f32>()),
empty_buffer,
@@ -178,7 +182,6 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity:
usize) -> [MutableBuff
],
_ => unreachable!(),
},
- DataType::Float16 => unreachable!(),
DataType::FixedSizeList(_, _) | DataType::Struct(_) => {
[empty_buffer, MutableBuffer::new(0)]
}
@@ -319,7 +322,7 @@ impl ArrayData {
buffers: Vec<Buffer>,
child_data: Vec<ArrayData>,
) -> Result<Self> {
- // Safetly justification: `validate` is (will be) called below
+ // Safety justification: `validate` is (will be) called below
let new_self = unsafe {
Self::new_unchecked(
data_type,
@@ -519,6 +522,7 @@ impl ArrayData {
| DataType::Int16
| DataType::Int32
| DataType::Int64
+ | DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Date32
@@ -554,7 +558,6 @@ impl ArrayData {
DataType::Dictionary(_, data_type) => {
vec![Self::new_empty(data_type)]
}
- DataType::Float16 => unreachable!(),
};
// Data was constructed correctly above
diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs
index 15d41a0..0e8d8bb 100644
--- a/arrow/src/array/equal/mod.rs
+++ b/arrow/src/array/equal/mod.rs
@@ -25,11 +25,11 @@ use super::{
GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray,
StringOffsetSizeTrait, StructArray,
};
-
use crate::{
buffer::Buffer,
datatypes::{ArrowPrimitiveType, DataType, IntervalUnit},
};
+use half::f16;
mod boolean;
mod decimal;
@@ -251,7 +251,9 @@ fn equal_values(
),
_ => unreachable!(),
},
- DataType::Float16 => unreachable!(),
+ DataType::Float16 => primitive_equal::<f16>(
+ lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
+ ),
DataType::Map(_, _) => {
list_equal::<i32>(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start,
rhs_start, len)
}
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 26b410e..4d9eaa7 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -192,6 +192,14 @@ pub type UInt64Array = PrimitiveArray<UInt64Type>;
///
/// # Example: Using `collect`
/// ```
+/// # use arrow::array::Float16Array;
+/// use half::f16;
+/// let arr : Float16Array = [Some(f16::from_f64(1.0)),
Some(f16::from_f64(2.0))].into_iter().collect();
+/// ```
+pub type Float16Array = PrimitiveArray<Float16Type>;
+///
+/// # Example: Using `collect`
+/// ```
/// # use arrow::array::Float32Array;
/// let arr : Float32Array = [Some(1.0), Some(2.0)].into_iter().collect();
/// ```
diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs
index 2c18848..9ad3dbf 100644
--- a/arrow/src/array/transform/mod.rs
+++ b/arrow/src/array/transform/mod.rs
@@ -15,20 +15,20 @@
// specific language governing permissions and limitations
// under the License.
+use super::{
+ data::{into_buffers, new_buffers},
+ ArrayData, ArrayDataBuilder,
+};
+use crate::array::StringOffsetSizeTrait;
use crate::{
buffer::MutableBuffer,
datatypes::DataType,
error::{ArrowError, Result},
util::bit_util,
};
+use half::f16;
use std::mem;
-use super::{
- data::{into_buffers, new_buffers},
- ArrayData, ArrayDataBuilder,
-};
-use crate::array::StringOffsetSizeTrait;
-
mod boolean;
mod fixed_binary;
mod list;
@@ -266,7 +266,7 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::Dictionary(_, _) => unreachable!("should use
build_extend_dictionary"),
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
- DataType::Float16 => unreachable!(),
+ DataType::Float16 => primitive::build_extend::<f16>(array),
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
@@ -315,7 +315,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
},
DataType::Struct(_) => structure::extend_nulls,
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
- DataType::Float16 => unreachable!(),
+ DataType::Float16 => primitive::extend_nulls::<f16>,
/*
DataType::FixedSizeList(_, _) => {}
DataType::Union(_) => {}
@@ -429,6 +429,7 @@ impl<'a> MutableArrayData<'a> {
| DataType::Int16
| DataType::Int32
| DataType::Int64
+ | DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Date32
@@ -467,7 +468,6 @@ impl<'a> MutableArrayData<'a> {
}
// the dictionary type just appends keys and clones the values.
DataType::Dictionary(_, _) => vec![],
- DataType::Float16 => unreachable!(),
DataType::Struct(fields) => match capacities {
Capacities::Struct(capacity, Some(ref child_capacities)) => {
array_capacity = capacity;
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index 6e8cf89..18d593b 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use serde_json::{Number, Value};
-
use super::DataType;
+use half::f16;
+use serde_json::{Number, Value};
/// Trait declaring any type that is serializable to JSON. This includes all
primitive types (bool, i32, etc.).
pub trait JsonSerializable: 'static {
@@ -293,6 +293,12 @@ impl ArrowNativeType for u64 {
}
}
+impl JsonSerializable for f16 {
+ fn into_json_value(self) -> Option<Value> {
+ Number::from_f64(f64::round(f64::from(self) * 1000.0) /
1000.0).map(Value::Number)
+ }
+}
+
impl JsonSerializable for f32 {
fn into_json_value(self) -> Option<Value> {
Number::from_f64(f64::round(self as f64 * 1000.0) /
1000.0).map(Value::Number)
@@ -305,6 +311,7 @@ impl JsonSerializable for f64 {
}
}
+impl ArrowNativeType for f16 {}
impl ArrowNativeType for f32 {}
impl ArrowNativeType for f64 {}
diff --git a/arrow/src/datatypes/types.rs b/arrow/src/datatypes/types.rs
index 30c9aae..2731e3d 100644
--- a/arrow/src/datatypes/types.rs
+++ b/arrow/src/datatypes/types.rs
@@ -16,6 +16,7 @@
// under the License.
use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit};
+use half::f16;
// BooleanType is special: its bit-width is not the size of the primitive
type, and its `index`
// operation assumes bit-packing.
@@ -46,6 +47,7 @@ make_type!(UInt8Type, u8, DataType::UInt8);
make_type!(UInt16Type, u16, DataType::UInt16);
make_type!(UInt32Type, u32, DataType::UInt32);
make_type!(UInt64Type, u64, DataType::UInt64);
+make_type!(Float16Type, f16, DataType::Float16);
make_type!(Float32Type, f32, DataType::Float32);
make_type!(Float64Type, f64, DataType::Float64);
make_type!(