This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new fe536495de Extract parquet statistics from f16 columns, add `ScalarValue::Float16` (#10763)
fe536495de is described below

commit fe536495de8116d41f62c550fd000df9c3d98aab
Author: Lordworms <[email protected]>
AuthorDate: Mon Jun 3 14:27:49 2024 -0700

    Extract parquet statistics from f16 columns, add `ScalarValue::Float16` (#10763)
    
    * Extract parquet statistics from f16 columns
    
    * Update datafusion/common/src/scalar/mod.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/scalar/mod.rs                | 39 +++++++++++++--
 .../datasource/physical_plan/parquet/statistics.rs | 14 +++++-
 datafusion/core/tests/parquet/arrow_statistics.rs  | 43 ++++++++++++++---
 datafusion/core/tests/parquet/mod.rs               | 55 +++++++++++++++++-----
 datafusion/proto-common/src/to_proto/mod.rs        |  5 ++
 datafusion/sql/src/unparser/expr.rs                |  4 ++
 6 files changed, 136 insertions(+), 24 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index d2c6513eef..ba006247cd 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -18,13 +18,13 @@
 //! [`ScalarValue`]: stores single  values
 
 mod struct_builder;
-
 use std::borrow::Borrow;
 use std::cmp::Ordering;
 use std::collections::{HashSet, VecDeque};
 use std::convert::Infallible;
 use std::fmt;
 use std::hash::Hash;
+use std::hash::Hasher;
 use std::iter::repeat;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -55,6 +55,7 @@ use arrow::{
 use arrow_buffer::Buffer;
 use arrow_schema::{UnionFields, UnionMode};
 
+use half::f16;
 pub use struct_builder::ScalarStructBuilder;
 
 /// A dynamically typed, nullable single value.
@@ -192,6 +193,8 @@ pub enum ScalarValue {
     Null,
     /// true or false value
     Boolean(Option<bool>),
+    /// 16bit float
+    Float16(Option<f16>),
     /// 32bit float
     Float32(Option<f32>),
     /// 64bit float
@@ -285,6 +288,12 @@ pub enum ScalarValue {
     Dictionary(Box<DataType>, Box<ScalarValue>),
 }
 
+impl Hash for Fl<f16> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.0.to_bits().hash(state);
+    }
+}
+
 // manual implementation of `PartialEq`
 impl PartialEq for ScalarValue {
     fn eq(&self, other: &Self) -> bool {
@@ -307,7 +316,12 @@ impl PartialEq for ScalarValue {
                 (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
                 _ => v1.eq(v2),
             },
+            (Float16(v1), Float16(v2)) => match (v1, v2) {
+                (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
+                _ => v1.eq(v2),
+            },
             (Float32(_), _) => false,
+            (Float16(_), _) => false,
             (Float64(v1), Float64(v2)) => match (v1, v2) {
                 (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
                 _ => v1.eq(v2),
@@ -425,7 +439,12 @@ impl PartialOrd for ScalarValue {
                 (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
                 _ => v1.partial_cmp(v2),
             },
+            (Float16(v1), Float16(v2)) => match (v1, v2) {
+                (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
+                _ => v1.partial_cmp(v2),
+            },
             (Float32(_), _) => None,
+            (Float16(_), _) => None,
             (Float64(v1), Float64(v2)) => match (v1, v2) {
                 (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
                 _ => v1.partial_cmp(v2),
@@ -637,6 +656,7 @@ impl std::hash::Hash for ScalarValue {
                 s.hash(state)
             }
             Boolean(v) => v.hash(state),
+            Float16(v) => v.map(Fl).hash(state),
             Float32(v) => v.map(Fl).hash(state),
             Float64(v) => v.map(Fl).hash(state),
             Int8(v) => v.hash(state),
@@ -1082,6 +1102,7 @@ impl ScalarValue {
             ScalarValue::TimestampNanosecond(_, tz_opt) => {
                 DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
             }
+            ScalarValue::Float16(_) => DataType::Float16,
             ScalarValue::Float32(_) => DataType::Float32,
             ScalarValue::Float64(_) => DataType::Float64,
             ScalarValue::Utf8(_) => DataType::Utf8,
@@ -1276,6 +1297,7 @@ impl ScalarValue {
         match self {
             ScalarValue::Boolean(v) => v.is_none(),
             ScalarValue::Null => true,
+            ScalarValue::Float16(v) => v.is_none(),
             ScalarValue::Float32(v) => v.is_none(),
             ScalarValue::Float64(v) => v.is_none(),
             ScalarValue::Decimal128(v, _, _) => v.is_none(),
@@ -1522,6 +1544,7 @@ impl ScalarValue {
             }
             DataType::Null => ScalarValue::iter_to_null_array(scalars)?,
             DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
+            DataType::Float16 => build_array_primitive!(Float16Array, Float16),
             DataType::Float32 => build_array_primitive!(Float32Array, Float32),
             DataType::Float64 => build_array_primitive!(Float64Array, Float64),
             DataType::Int8 => build_array_primitive!(Int8Array, Int8),
@@ -1682,8 +1705,7 @@ impl ScalarValue {
             // not supported if the TimeUnit is not valid (Time32 can
             // only be used with Second and Millisecond, Time64 only
             // with Microsecond and Nanosecond)
-            DataType::Float16
-            | DataType::Time32(TimeUnit::Microsecond)
+            DataType::Time32(TimeUnit::Microsecond)
             | DataType::Time32(TimeUnit::Nanosecond)
             | DataType::Time64(TimeUnit::Second)
             | DataType::Time64(TimeUnit::Millisecond)
@@ -1700,7 +1722,6 @@ impl ScalarValue {
                 );
             }
         };
-
         Ok(array)
     }
 
@@ -1921,6 +1942,9 @@ impl ScalarValue {
             ScalarValue::Float32(e) => {
                 build_array_from_option!(Float32, Float32Array, e, size)
             }
+            ScalarValue::Float16(e) => {
+                build_array_from_option!(Float16, Float16Array, e, size)
+            }
             ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, 
e, size),
             ScalarValue::Int16(e) => build_array_from_option!(Int16, 
Int16Array, e, size),
             ScalarValue::Int32(e) => build_array_from_option!(Int32, 
Int32Array, e, size),
@@ -2595,6 +2619,9 @@ impl ScalarValue {
             ScalarValue::Boolean(val) => {
                 eq_array_primitive!(array, index, BooleanArray, val)?
             }
+            ScalarValue::Float16(val) => {
+                eq_array_primitive!(array, index, Float16Array, val)?
+            }
             ScalarValue::Float32(val) => {
                 eq_array_primitive!(array, index, Float32Array, val)?
             }
@@ -2738,6 +2765,7 @@ impl ScalarValue {
             + match self {
                 ScalarValue::Null
                 | ScalarValue::Boolean(_)
+                | ScalarValue::Float16(_)
                 | ScalarValue::Float32(_)
                 | ScalarValue::Float64(_)
                 | ScalarValue::Decimal128(_, _, _)
@@ -3022,6 +3050,7 @@ impl TryFrom<&DataType> for ScalarValue {
     fn try_from(data_type: &DataType) -> Result<Self> {
         Ok(match data_type {
             DataType::Boolean => ScalarValue::Boolean(None),
+            DataType::Float16 => ScalarValue::Float16(None),
             DataType::Float64 => ScalarValue::Float64(None),
             DataType::Float32 => ScalarValue::Float32(None),
             DataType::Int8 => ScalarValue::Int8(None),
@@ -3147,6 +3176,7 @@ impl fmt::Display for ScalarValue {
                 write!(f, "{v:?},{p:?},{s:?}")?;
             }
             ScalarValue::Boolean(e) => format_option!(f, e)?,
+            ScalarValue::Float16(e) => format_option!(f, e)?,
             ScalarValue::Float32(e) => format_option!(f, e)?,
             ScalarValue::Float64(e) => format_option!(f, e)?,
             ScalarValue::Int8(e) => format_option!(f, e)?,
@@ -3260,6 +3290,7 @@ impl fmt::Debug for ScalarValue {
             ScalarValue::Decimal128(_, _, _) => write!(f, 
"Decimal128({self})"),
             ScalarValue::Decimal256(_, _, _) => write!(f, 
"Decimal256({self})"),
             ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
+            ScalarValue::Float16(_) => write!(f, "Float16({self})"),
             ScalarValue::Float32(_) => write!(f, "Float32({self})"),
             ScalarValue::Float64(_) => write!(f, "Float64({self})"),
             ScalarValue::Int8(_) => write!(f, "Int8({self})"),
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
index e7e6360c25..6c738cfe03 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
@@ -25,11 +25,11 @@ use arrow_schema::{Field, FieldRef, Schema};
 use datafusion_common::{
     internal_datafusion_err, internal_err, plan_err, Result, ScalarValue,
 };
+use half::f16;
 use parquet::file::metadata::ParquetMetaData;
 use parquet::file::statistics::Statistics as ParquetStatistics;
 use parquet::schema::types::SchemaDescriptor;
 use std::sync::Arc;
-
 // Convert the bytes array to i128.
 // The endian of the input bytes array must be big-endian.
 pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
@@ -39,6 +39,14 @@ pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
     i128::from_be_bytes(sign_extend_be(b))
 }
 
+// Convert the bytes array to f16
+pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option<f16> {
+    match b {
+        [low, high] => Some(f16::from_be_bytes([*high, *low])),
+        _ => None,
+    }
+}
+
 // Copy from arrow-rs
 // 
https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55
 // Convert the byte slice to fixed length byte array with the length of 16
@@ -196,6 +204,9 @@ macro_rules! get_statistic {
                             value,
                         ))
                     }
+                    Some(DataType::Float16) => {
+                        
Some(ScalarValue::Float16(from_bytes_to_f16(s.$bytes_func())))
+                    }
                     _ => None,
                 }
             }
@@ -344,7 +355,6 @@ impl<'a> StatisticsConverter<'a> {
                 column_name
             );
         };
-
         Ok(Self {
             column_name,
             statistics_type,
diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs
index aa5fc7c34c..c2bf75c8f0 100644
--- a/datafusion/core/tests/parquet/arrow_statistics.rs
+++ b/datafusion/core/tests/parquet/arrow_statistics.rs
@@ -21,6 +21,7 @@
 use std::fs::File;
 use std::sync::Arc;
 
+use crate::parquet::{struct_array, Scenario};
 use arrow::compute::kernels::cast_utils::Parser;
 use arrow::datatypes::{
     Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType,
@@ -28,21 +29,21 @@ use arrow::datatypes::{
 };
 use arrow_array::{
     make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, 
Date64Array,
-    Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, 
Int16Array,
-    Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, 
StringArray,
-    TimestampMicrosecondArray, TimestampMillisecondArray, 
TimestampNanosecondArray,
-    TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+    Decimal128Array, FixedSizeBinaryArray, Float16Array, Float32Array, 
Float64Array,
+    Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, 
RecordBatch,
+    StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+    TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
+    UInt64Array, UInt8Array,
 };
 use arrow_schema::{DataType, Field, Schema};
 use datafusion::datasource::physical_plan::parquet::{
     RequestedStatistics, StatisticsConverter,
 };
+use half::f16;
 use parquet::arrow::arrow_reader::{ArrowReaderBuilder, 
ParquetRecordBatchReaderBuilder};
 use parquet::arrow::ArrowWriter;
 use parquet::file::properties::{EnabledStatistics, WriterProperties};
 
-use crate::parquet::{struct_array, Scenario};
-
 use super::make_test_file_rg;
 
 // TEST HELPERS
@@ -1203,6 +1204,36 @@ async fn test_float64() {
     .run();
 }
 
+#[tokio::test]
+async fn test_float16() {
+    // This creates a parquet file of 1 column "f"
+    // file has 4 record batches, each has 5 rows. They will be saved into 4 
row groups
+    let reader = TestReader {
+        scenario: Scenario::Float16,
+        row_per_group: 5,
+    };
+
+    Test {
+        reader: reader.build().await,
+        expected_min: Arc::new(Float16Array::from(
+            vec![-5.0, -4.0, -0.0, 5.0]
+                .into_iter()
+                .map(f16::from_f32)
+                .collect::<Vec<_>>(),
+        )),
+        expected_max: Arc::new(Float16Array::from(
+            vec![-1.0, 0.0, 4.0, 9.0]
+                .into_iter()
+                .map(f16::from_f32)
+                .collect::<Vec<_>>(),
+        )),
+        expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]),
+        expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]),
+        column_name: "f",
+    }
+    .run();
+}
+
 #[tokio::test]
 async fn test_decimal() {
     // This creates a parquet file of 1 column "decimal_col" with decimal data 
type and precicion 9, scale 2
diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs
index bfb6e8e555..e951644f2c 100644
--- a/datafusion/core/tests/parquet/mod.rs
+++ b/datafusion/core/tests/parquet/mod.rs
@@ -19,20 +19,17 @@
 use arrow::array::Decimal128Array;
 use arrow::{
     array::{
-        Array, ArrayRef, BinaryArray, Date32Array, Date64Array, 
FixedSizeBinaryArray,
-        Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, 
StringArray,
-        TimestampMicrosecondArray, TimestampMillisecondArray, 
TimestampNanosecondArray,
-        TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, 
UInt8Array,
+        make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, 
Date64Array,
+        DictionaryArray, FixedSizeBinaryArray, Float16Array, Float32Array, 
Float64Array,
+        Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, 
StringArray,
+        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+        TimestampNanosecondArray, TimestampSecondArray, UInt16Array, 
UInt32Array,
+        UInt64Array, UInt8Array,
     },
-    datatypes::{DataType, Field, Schema},
+    datatypes::{DataType, Field, Int32Type, Int8Type, Schema},
     record_batch::RecordBatch,
     util::pretty::pretty_format_batches,
 };
-use arrow_array::types::{Int32Type, Int8Type};
-use arrow_array::{
-    make_array, BooleanArray, DictionaryArray, Float32Array, LargeStringArray,
-    StructArray,
-};
 use chrono::{Datelike, Duration, TimeDelta};
 use datafusion::{
     datasource::{physical_plan::ParquetExec, provider_as_source, 
TableProvider},
@@ -40,11 +37,11 @@ use datafusion::{
     prelude::{ParquetReadOptions, SessionConfig, SessionContext},
 };
 use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
+use half::f16;
 use parquet::arrow::ArrowWriter;
 use parquet::file::properties::WriterProperties;
 use std::sync::Arc;
 use tempfile::NamedTempFile;
-
 mod arrow_statistics;
 mod custom_reader;
 mod file_statistics;
@@ -79,6 +76,7 @@ enum Scenario {
     /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64
     /// -MIN, -100, -1, 0, 1, 100, MAX
     NumericLimits,
+    Float16,
     Float64,
     Decimal,
     DecimalBloomFilterInt32,
@@ -542,6 +540,12 @@ fn make_f64_batch(v: Vec<f64>) -> RecordBatch {
     RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
 }
 
+fn make_f16_batch(v: Vec<f16>) -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16, 
true)]));
+    let array = Arc::new(Float16Array::from(v)) as ArrayRef;
+    RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
+}
+
 /// Return record batch with decimal vector
 ///
 /// Columns are named
@@ -897,6 +901,34 @@ fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> {
         Scenario::NumericLimits => {
             vec![make_numeric_limit_batch()]
         }
+        Scenario::Float16 => {
+            vec![
+                make_f16_batch(
+                    vec![-5.0, -4.0, -3.0, -2.0, -1.0]
+                        .into_iter()
+                        .map(f16::from_f32)
+                        .collect(),
+                ),
+                make_f16_batch(
+                    vec![-4.0, -3.0, -2.0, -1.0, 0.0]
+                        .into_iter()
+                        .map(f16::from_f32)
+                        .collect(),
+                ),
+                make_f16_batch(
+                    vec![0.0, 1.0, 2.0, 3.0, 4.0]
+                        .into_iter()
+                        .map(f16::from_f32)
+                        .collect(),
+                ),
+                make_f16_batch(
+                    vec![5.0, 6.0, 7.0, 8.0, 9.0]
+                        .into_iter()
+                        .map(f16::from_f32)
+                        .collect(),
+                ),
+            ]
+        }
         Scenario::Float64 => {
             vec![
                 make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]),
@@ -1087,7 +1119,6 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem
         .build();
 
     let batches = create_data_batch(scenario);
-
     let schema = batches[0].schema();
 
     let mut writer = ArrowWriter::try_new(&mut output_file, schema, 
Some(props)).unwrap();
diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs
index f160bc40af..a92deaa88b 100644
--- a/datafusion/proto-common/src/to_proto/mod.rs
+++ b/datafusion/proto-common/src/to_proto/mod.rs
@@ -294,6 +294,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
             ScalarValue::Boolean(val) => {
                 create_proto_scalar(val.as_ref(), &data_type, |s| 
Value::BoolValue(*s))
             }
+            ScalarValue::Float16(val) => {
+                create_proto_scalar(val.as_ref(), &data_type, |s| {
+                    Value::Float32Value((*s).into())
+                })
+            }
             ScalarValue::Float32(val) => {
                 create_proto_scalar(val.as_ref(), &data_type, |s| 
Value::Float32Value(*s))
             }
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index 1ba6638e73..3efbe2ace6 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -643,6 +643,10 @@ impl Unparser<'_> {
                 Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned())))
             }
             ScalarValue::Boolean(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Float16(Some(f)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
+            }
+            ScalarValue::Float16(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Float32(Some(f)) => {
                 Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
             }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to