This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fe536495de Extract parquet statistics from f16 columns, add
`ScalarValue::Float16` (#10763)
fe536495de is described below
commit fe536495de8116d41f62c550fd000df9c3d98aab
Author: Lordworms <[email protected]>
AuthorDate: Mon Jun 3 14:27:49 2024 -0700
Extract parquet statistics from f16 columns, add `ScalarValue::Float16`
(#10763)
* Extract parquet statistics from f16 columns
* Update datafusion/common/src/scalar/mod.rs
Co-authored-by: Andrew Lamb <[email protected]>
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/scalar/mod.rs | 39 +++++++++++++--
.../datasource/physical_plan/parquet/statistics.rs | 14 +++++-
datafusion/core/tests/parquet/arrow_statistics.rs | 43 ++++++++++++++---
datafusion/core/tests/parquet/mod.rs | 55 +++++++++++++++++-----
datafusion/proto-common/src/to_proto/mod.rs | 5 ++
datafusion/sql/src/unparser/expr.rs | 4 ++
6 files changed, 136 insertions(+), 24 deletions(-)
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index d2c6513eef..ba006247cd 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -18,13 +18,13 @@
//! [`ScalarValue`]: stores single values
mod struct_builder;
-
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::convert::Infallible;
use std::fmt;
use std::hash::Hash;
+use std::hash::Hasher;
use std::iter::repeat;
use std::str::FromStr;
use std::sync::Arc;
@@ -55,6 +55,7 @@ use arrow::{
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};
+use half::f16;
pub use struct_builder::ScalarStructBuilder;
/// A dynamically typed, nullable single value.
@@ -192,6 +193,8 @@ pub enum ScalarValue {
Null,
/// true or false value
Boolean(Option<bool>),
+ /// 16bit float
+ Float16(Option<f16>),
/// 32bit float
Float32(Option<f32>),
/// 64bit float
@@ -285,6 +288,12 @@ pub enum ScalarValue {
Dictionary(Box<DataType>, Box<ScalarValue>),
}
+impl Hash for Fl<f16> {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.0.to_bits().hash(state);
+ }
+}
+
// manual implementation of `PartialEq`
impl PartialEq for ScalarValue {
fn eq(&self, other: &Self) -> bool {
@@ -307,7 +316,12 @@ impl PartialEq for ScalarValue {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
+ (Float16(v1), Float16(v2)) => match (v1, v2) {
+ (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
+ _ => v1.eq(v2),
+ },
(Float32(_), _) => false,
+ (Float16(_), _) => false,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
@@ -425,7 +439,12 @@ impl PartialOrd for ScalarValue {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
+ (Float16(v1), Float16(v2)) => match (v1, v2) {
+ (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
+ _ => v1.partial_cmp(v2),
+ },
(Float32(_), _) => None,
+ (Float16(_), _) => None,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
@@ -637,6 +656,7 @@ impl std::hash::Hash for ScalarValue {
s.hash(state)
}
Boolean(v) => v.hash(state),
+ Float16(v) => v.map(Fl).hash(state),
Float32(v) => v.map(Fl).hash(state),
Float64(v) => v.map(Fl).hash(state),
Int8(v) => v.hash(state),
@@ -1082,6 +1102,7 @@ impl ScalarValue {
ScalarValue::TimestampNanosecond(_, tz_opt) => {
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
}
+ ScalarValue::Float16(_) => DataType::Float16,
ScalarValue::Float32(_) => DataType::Float32,
ScalarValue::Float64(_) => DataType::Float64,
ScalarValue::Utf8(_) => DataType::Utf8,
@@ -1276,6 +1297,7 @@ impl ScalarValue {
match self {
ScalarValue::Boolean(v) => v.is_none(),
ScalarValue::Null => true,
+ ScalarValue::Float16(v) => v.is_none(),
ScalarValue::Float32(v) => v.is_none(),
ScalarValue::Float64(v) => v.is_none(),
ScalarValue::Decimal128(v, _, _) => v.is_none(),
@@ -1522,6 +1544,7 @@ impl ScalarValue {
}
DataType::Null => ScalarValue::iter_to_null_array(scalars)?,
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
+ DataType::Float16 => build_array_primitive!(Float16Array, Float16),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
DataType::Int8 => build_array_primitive!(Int8Array, Int8),
@@ -1682,8 +1705,7 @@ impl ScalarValue {
// not supported if the TimeUnit is not valid (Time32 can
// only be used with Second and Millisecond, Time64 only
// with Microsecond and Nanosecond)
- DataType::Float16
- | DataType::Time32(TimeUnit::Microsecond)
+ DataType::Time32(TimeUnit::Microsecond)
| DataType::Time32(TimeUnit::Nanosecond)
| DataType::Time64(TimeUnit::Second)
| DataType::Time64(TimeUnit::Millisecond)
@@ -1700,7 +1722,6 @@ impl ScalarValue {
);
}
};
-
Ok(array)
}
@@ -1921,6 +1942,9 @@ impl ScalarValue {
ScalarValue::Float32(e) => {
build_array_from_option!(Float32, Float32Array, e, size)
}
+ ScalarValue::Float16(e) => {
+ build_array_from_option!(Float16, Float16Array, e, size)
+ }
ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array,
e, size),
ScalarValue::Int16(e) => build_array_from_option!(Int16,
Int16Array, e, size),
ScalarValue::Int32(e) => build_array_from_option!(Int32,
Int32Array, e, size),
@@ -2595,6 +2619,9 @@ impl ScalarValue {
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)?
}
+ ScalarValue::Float16(val) => {
+ eq_array_primitive!(array, index, Float16Array, val)?
+ }
ScalarValue::Float32(val) => {
eq_array_primitive!(array, index, Float32Array, val)?
}
@@ -2738,6 +2765,7 @@ impl ScalarValue {
+ match self {
ScalarValue::Null
| ScalarValue::Boolean(_)
+ | ScalarValue::Float16(_)
| ScalarValue::Float32(_)
| ScalarValue::Float64(_)
| ScalarValue::Decimal128(_, _, _)
@@ -3022,6 +3050,7 @@ impl TryFrom<&DataType> for ScalarValue {
fn try_from(data_type: &DataType) -> Result<Self> {
Ok(match data_type {
DataType::Boolean => ScalarValue::Boolean(None),
+ DataType::Float16 => ScalarValue::Float16(None),
DataType::Float64 => ScalarValue::Float64(None),
DataType::Float32 => ScalarValue::Float32(None),
DataType::Int8 => ScalarValue::Int8(None),
@@ -3147,6 +3176,7 @@ impl fmt::Display for ScalarValue {
write!(f, "{v:?},{p:?},{s:?}")?;
}
ScalarValue::Boolean(e) => format_option!(f, e)?,
+ ScalarValue::Float16(e) => format_option!(f, e)?,
ScalarValue::Float32(e) => format_option!(f, e)?,
ScalarValue::Float64(e) => format_option!(f, e)?,
ScalarValue::Int8(e) => format_option!(f, e)?,
@@ -3260,6 +3290,7 @@ impl fmt::Debug for ScalarValue {
ScalarValue::Decimal128(_, _, _) => write!(f,
"Decimal128({self})"),
ScalarValue::Decimal256(_, _, _) => write!(f,
"Decimal256({self})"),
ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
+ ScalarValue::Float16(_) => write!(f, "Float16({self})"),
ScalarValue::Float32(_) => write!(f, "Float32({self})"),
ScalarValue::Float64(_) => write!(f, "Float64({self})"),
ScalarValue::Int8(_) => write!(f, "Int8({self})"),
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
index e7e6360c25..6c738cfe03 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
@@ -25,11 +25,11 @@ use arrow_schema::{Field, FieldRef, Schema};
use datafusion_common::{
internal_datafusion_err, internal_err, plan_err, Result, ScalarValue,
};
+use half::f16;
use parquet::file::metadata::ParquetMetaData;
use parquet::file::statistics::Statistics as ParquetStatistics;
use parquet::schema::types::SchemaDescriptor;
use std::sync::Arc;
-
// Convert the bytes array to i128.
// The endian of the input bytes array must be big-endian.
pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
@@ -39,6 +39,14 @@ pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
i128::from_be_bytes(sign_extend_be(b))
}
+// Convert the bytes array to f16
+pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option<f16> {
+ match b {
+ [low, high] => Some(f16::from_be_bytes([*high, *low])),
+ _ => None,
+ }
+}
+
// Copy from arrow-rs
// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55
// Convert the byte slice to fixed length byte array with the length of 16
@@ -196,6 +204,9 @@ macro_rules! get_statistic {
value,
))
}
+ Some(DataType::Float16) => {
+ Some(ScalarValue::Float16(from_bytes_to_f16(s.$bytes_func())))
+ }
_ => None,
}
}
@@ -344,7 +355,6 @@ impl<'a> StatisticsConverter<'a> {
column_name
);
};
-
Ok(Self {
column_name,
statistics_type,
diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs
b/datafusion/core/tests/parquet/arrow_statistics.rs
index aa5fc7c34c..c2bf75c8f0 100644
--- a/datafusion/core/tests/parquet/arrow_statistics.rs
+++ b/datafusion/core/tests/parquet/arrow_statistics.rs
@@ -21,6 +21,7 @@
use std::fs::File;
use std::sync::Arc;
+use crate::parquet::{struct_array, Scenario};
use arrow::compute::kernels::cast_utils::Parser;
use arrow::datatypes::{
Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType,
@@ -28,21 +29,21 @@ use arrow::datatypes::{
};
use arrow_array::{
make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array,
Date64Array,
- Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array,
Int16Array,
- Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch,
StringArray,
- TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray,
- TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+ Decimal128Array, FixedSizeBinaryArray, Float16Array, Float32Array,
Float64Array,
+ Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray,
RecordBatch,
+ StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+ TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
+ UInt64Array, UInt8Array,
};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::physical_plan::parquet::{
RequestedStatistics, StatisticsConverter,
};
+use half::f16;
use parquet::arrow::arrow_reader::{ArrowReaderBuilder,
ParquetRecordBatchReaderBuilder};
use parquet::arrow::ArrowWriter;
use parquet::file::properties::{EnabledStatistics, WriterProperties};
-use crate::parquet::{struct_array, Scenario};
-
use super::make_test_file_rg;
// TEST HELPERS
@@ -1203,6 +1204,36 @@ async fn test_float64() {
.run();
}
+#[tokio::test]
+async fn test_float16() {
+ // This creates a parquet file of 1 column "f"
+ // file has 4 record batches, each has 5 rows. They will be saved into 4 row groups
+ let reader = TestReader {
+ scenario: Scenario::Float16,
+ row_per_group: 5,
+ };
+
+ Test {
+ reader: reader.build().await,
+ expected_min: Arc::new(Float16Array::from(
+ vec![-5.0, -4.0, -0.0, 5.0]
+ .into_iter()
+ .map(f16::from_f32)
+ .collect::<Vec<_>>(),
+ )),
+ expected_max: Arc::new(Float16Array::from(
+ vec![-1.0, 0.0, 4.0, 9.0]
+ .into_iter()
+ .map(f16::from_f32)
+ .collect::<Vec<_>>(),
+ )),
+ expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]),
+ expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]),
+ column_name: "f",
+ }
+ .run();
+}
+
#[tokio::test]
async fn test_decimal() {
// This creates a parquet file of 1 column "decimal_col" with decimal data type and precision 9, scale 2
diff --git a/datafusion/core/tests/parquet/mod.rs
b/datafusion/core/tests/parquet/mod.rs
index bfb6e8e555..e951644f2c 100644
--- a/datafusion/core/tests/parquet/mod.rs
+++ b/datafusion/core/tests/parquet/mod.rs
@@ -19,20 +19,17 @@
use arrow::array::Decimal128Array;
use arrow::{
array::{
- Array, ArrayRef, BinaryArray, Date32Array, Date64Array,
FixedSizeBinaryArray,
- Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
StringArray,
- TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray,
- TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
+ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array,
Date64Array,
+ DictionaryArray, FixedSizeBinaryArray, Float16Array, Float32Array,
Float64Array,
+ Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray,
StringArray,
+ StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+ TimestampNanosecondArray, TimestampSecondArray, UInt16Array,
UInt32Array,
+ UInt64Array, UInt8Array,
},
- datatypes::{DataType, Field, Schema},
+ datatypes::{DataType, Field, Int32Type, Int8Type, Schema},
record_batch::RecordBatch,
util::pretty::pretty_format_batches,
};
-use arrow_array::types::{Int32Type, Int8Type};
-use arrow_array::{
- make_array, BooleanArray, DictionaryArray, Float32Array, LargeStringArray,
- StructArray,
-};
use chrono::{Datelike, Duration, TimeDelta};
use datafusion::{
datasource::{physical_plan::ParquetExec, provider_as_source,
TableProvider},
@@ -40,11 +37,11 @@ use datafusion::{
prelude::{ParquetReadOptions, SessionConfig, SessionContext},
};
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
+use half::f16;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;
use std::sync::Arc;
use tempfile::NamedTempFile;
-
mod arrow_statistics;
mod custom_reader;
mod file_statistics;
@@ -79,6 +76,7 @@ enum Scenario {
/// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64
/// -MIN, -100, -1, 0, 1, 100, MAX
NumericLimits,
+ Float16,
Float64,
Decimal,
DecimalBloomFilterInt32,
@@ -542,6 +540,12 @@ fn make_f64_batch(v: Vec<f64>) -> RecordBatch {
RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
}
+fn make_f16_batch(v: Vec<f16>) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16,
true)]));
+ let array = Arc::new(Float16Array::from(v)) as ArrayRef;
+ RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
+}
+
/// Return record batch with decimal vector
///
/// Columns are named
@@ -897,6 +901,34 @@ fn create_data_batch(scenario: Scenario) ->
Vec<RecordBatch> {
Scenario::NumericLimits => {
vec![make_numeric_limit_batch()]
}
+ Scenario::Float16 => {
+ vec![
+ make_f16_batch(
+ vec![-5.0, -4.0, -3.0, -2.0, -1.0]
+ .into_iter()
+ .map(f16::from_f32)
+ .collect(),
+ ),
+ make_f16_batch(
+ vec![-4.0, -3.0, -2.0, -1.0, 0.0]
+ .into_iter()
+ .map(f16::from_f32)
+ .collect(),
+ ),
+ make_f16_batch(
+ vec![0.0, 1.0, 2.0, 3.0, 4.0]
+ .into_iter()
+ .map(f16::from_f32)
+ .collect(),
+ ),
+ make_f16_batch(
+ vec![5.0, 6.0, 7.0, 8.0, 9.0]
+ .into_iter()
+ .map(f16::from_f32)
+ .collect(),
+ ),
+ ]
+ }
Scenario::Float64 => {
vec![
make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]),
@@ -1087,7 +1119,6 @@ async fn make_test_file_rg(scenario: Scenario,
row_per_group: usize) -> NamedTem
.build();
let batches = create_data_batch(scenario);
-
let schema = batches[0].schema();
let mut writer = ArrowWriter::try_new(&mut output_file, schema,
Some(props)).unwrap();
diff --git a/datafusion/proto-common/src/to_proto/mod.rs
b/datafusion/proto-common/src/to_proto/mod.rs
index f160bc40af..a92deaa88b 100644
--- a/datafusion/proto-common/src/to_proto/mod.rs
+++ b/datafusion/proto-common/src/to_proto/mod.rs
@@ -294,6 +294,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
ScalarValue::Boolean(val) => {
create_proto_scalar(val.as_ref(), &data_type, |s|
Value::BoolValue(*s))
}
+ ScalarValue::Float16(val) => {
+ create_proto_scalar(val.as_ref(), &data_type, |s| {
+ Value::Float32Value((*s).into())
+ })
+ }
ScalarValue::Float32(val) => {
create_proto_scalar(val.as_ref(), &data_type, |s|
Value::Float32Value(*s))
}
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index 1ba6638e73..3efbe2ace6 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -643,6 +643,10 @@ impl Unparser<'_> {
Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned())))
}
ScalarValue::Boolean(None) =>
Ok(ast::Expr::Value(ast::Value::Null)),
+ ScalarValue::Float16(Some(f)) => {
+ Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
+ }
+ ScalarValue::Float16(None) => Ok(ast::Expr::Value(ast::Value::Null)),
ScalarValue::Float32(Some(f)) => {
Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]