This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 7ba36b0123 Parquet: read/write f16 for Arrow (#5003)
7ba36b0123 is described below
commit 7ba36b012322e08b06184c806f8ba339181cebc1
Author: Jeffrey <[email protected]>
AuthorDate: Tue Nov 14 10:12:27 2023 +1100
Parquet: read/write f16 for Arrow (#5003)
* Support for read/write f16 Parquet to Arrow
* Update parquet/src/arrow/arrow_writer/mod.rs
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
* Update parquet/src/arrow/arrow_reader/mod.rs
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
* Update test with null version
* Fix schema tests and parsing for f16
* f16 for record api
* Handle NaN for f16 statistics writing
* Revert formatting changes
* Fix num trait
* Fix half feature
* Handle writing signed zero statistics
* Bump parquet-testing and read new f16 files for test
---------
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
---
parquet-testing | 2 +-
parquet/Cargo.toml | 1 +
parquet/regen.sh | 2 +-
.../src/arrow/array_reader/fixed_len_byte_array.rs | 17 +-
parquet/src/arrow/arrow_reader/mod.rs | 119 +++++++++++-
parquet/src/arrow/arrow_writer/mod.rs | 16 ++
parquet/src/arrow/schema/mod.rs | 17 +-
parquet/src/arrow/schema/primitive.rs | 10 +
parquet/src/basic.rs | 15 +-
parquet/src/column/writer/encoder.rs | 19 +-
parquet/src/column/writer/mod.rs | 204 ++++++++++++++++++++-
parquet/src/data_type.rs | 7 +
parquet/src/file/statistics.rs | 4 +
parquet/src/format.rs | 88 ++++++++-
parquet/src/record/api.rs | 88 ++++++++-
parquet/src/schema/parser.rs | 8 +
parquet/src/schema/printer.rs | 10 +
parquet/src/schema/types.rs | 44 +++++
18 files changed, 646 insertions(+), 25 deletions(-)
diff --git a/parquet-testing b/parquet-testing
index aafd3fc9df..506afff9b6 160000
--- a/parquet-testing
+++ b/parquet-testing
@@ -1 +1 @@
-Subproject commit aafd3fc9df431c2625a514fb46626e5614f1d199
+Subproject commit 506afff9b6957ffe10d08470d467867d43e1bb91
diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml
index e5f5e1652b..bdcbcb81cf 100644
--- a/parquet/Cargo.toml
+++ b/parquet/Cargo.toml
@@ -66,6 +66,7 @@ tokio = { version = "1.0", optional = true, default-features
= false, features =
hashbrown = { version = "0.14", default-features = false }
twox-hash = { version = "1.6", default-features = false }
paste = { version = "1.0" }
+half = { version = "2.1", default-features = false, features = ["num-traits"] }
[dev-dependencies]
base64 = { version = "0.21", default-features = false, features = ["std"] }
diff --git a/parquet/regen.sh b/parquet/regen.sh
index b8c3549e23..9153963433 100755
--- a/parquet/regen.sh
+++ b/parquet/regen.sh
@@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.
-REVISION=aeae80660c1d0c97314e9da837de1abdebd49c37
+REVISION=46cc3a0647d301bb9579ca8dd2cc356caf2a72d2
SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)"
diff --git a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs
b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs
index 3b1a50ebcc..b846997d36 100644
--- a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs
+++ b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs
@@ -27,13 +27,14 @@ use crate::column::reader::decoder::{ColumnValueDecoder,
ValuesBufferSlice};
use crate::errors::{ParquetError, Result};
use crate::schema::types::ColumnDescPtr;
use arrow_array::{
- ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray,
+ ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray,
Float16Array,
IntervalDayTimeArray, IntervalYearMonthArray,
};
use arrow_buffer::{i256, Buffer};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType as ArrowType, IntervalUnit};
use bytes::Bytes;
+use half::f16;
use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
@@ -88,6 +89,14 @@ pub fn make_fixed_len_byte_array_reader(
));
}
}
+ ArrowType::Float16 => {
+ if byte_length != 2 {
+ return Err(general_err!(
+ "float 16 type must be 2 bytes, got {}",
+ byte_length
+ ));
+ }
+ }
_ => {
return Err(general_err!(
"invalid data type for fixed length byte array reader - {}",
@@ -208,6 +217,12 @@ impl ArrayReader for FixedLenByteArrayReader {
}
}
}
+ ArrowType::Float16 => Arc::new(
+ binary
+ .iter()
+ .map(|o| o.map(|b| f16::from_le_bytes(b[..2].try_into().unwrap())))
+ .collect::<Float16Array>(),
+ ) as ArrayRef,
_ => Arc::new(binary) as ArrayRef,
};
diff --git a/parquet/src/arrow/arrow_reader/mod.rs
b/parquet/src/arrow/arrow_reader/mod.rs
index 16cdf2934e..b9e9d28984 100644
--- a/parquet/src/arrow/arrow_reader/mod.rs
+++ b/parquet/src/arrow/arrow_reader/mod.rs
@@ -712,13 +712,14 @@ mod tests {
use std::sync::Arc;
use bytes::Bytes;
+ use half::f16;
use num::PrimInt;
use rand::{thread_rng, Rng, RngCore};
use tempfile::tempfile;
use arrow_array::builder::*;
use arrow_array::cast::AsArray;
- use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType};
+ use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType,
Float16Type};
use arrow_array::*;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_buffer::{i256, ArrowNativeType, Buffer};
@@ -924,6 +925,66 @@ mod tests {
.unwrap();
}
+ #[test]
+ fn test_float16_roundtrip() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("float16", ArrowDataType::Float16, false),
+ Field::new("float16-nullable", ArrowDataType::Float16, true),
+ ]));
+
+ let mut buf = Vec::with_capacity(1024);
+ let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None)?;
+
+ let original = RecordBatch::try_new(
+ schema,
+ vec![
+ Arc::new(Float16Array::from_iter_values([
+ f16::EPSILON,
+ f16::MIN,
+ f16::MAX,
+ f16::NAN,
+ f16::INFINITY,
+ f16::NEG_INFINITY,
+ f16::ONE,
+ f16::NEG_ONE,
+ f16::ZERO,
+ f16::NEG_ZERO,
+ f16::E,
+ f16::PI,
+ f16::FRAC_1_PI,
+ ])),
+ Arc::new(Float16Array::from(vec![
+ None,
+ None,
+ None,
+ Some(f16::NAN),
+ Some(f16::INFINITY),
+ Some(f16::NEG_INFINITY),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ Some(f16::FRAC_1_PI),
+ ])),
+ ],
+ )?;
+
+ writer.write(&original)?;
+ writer.close()?;
+
+ let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buf),
1024)?;
+ let ret = reader.next().unwrap()?;
+ assert_eq!(ret, original);
+
+ // Ensure can be downcast to the correct type
+ ret.column(0).as_primitive::<Float16Type>();
+ ret.column(1).as_primitive::<Float16Type>();
+
+ Ok(())
+ }
+
struct RandFixedLenGen {}
impl RandGen<FixedLenByteArrayType> for RandFixedLenGen {
@@ -1255,6 +1316,62 @@ mod tests {
}
}
+ #[test]
+ fn test_read_float16_nonzeros_file() {
+ use arrow_array::Float16Array;
+ let testdata = arrow::util::test_util::parquet_test_data();
+ // see https://github.com/apache/parquet-testing/pull/40
+ let path = format!("{testdata}/float16_nonzeros_and_nans.parquet");
+ let file = File::open(path).unwrap();
+ let mut record_reader = ParquetRecordBatchReader::try_new(file,
32).unwrap();
+
+ let batch = record_reader.next().unwrap().unwrap();
+ assert_eq!(batch.num_rows(), 8);
+ let col = batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<Float16Array>()
+ .unwrap();
+
+ let f16_two = f16::ONE + f16::ONE;
+
+ assert_eq!(col.null_count(), 1);
+ assert!(col.is_null(0));
+ assert_eq!(col.value(1), f16::ONE);
+ assert_eq!(col.value(2), -f16_two);
+ assert!(col.value(3).is_nan());
+ assert_eq!(col.value(4), f16::ZERO);
+ assert!(col.value(4).is_sign_positive());
+ assert_eq!(col.value(5), f16::NEG_ONE);
+ assert_eq!(col.value(6), f16::NEG_ZERO);
+ assert!(col.value(6).is_sign_negative());
+ assert_eq!(col.value(7), f16_two);
+ }
+
+ #[test]
+ fn test_read_float16_zeros_file() {
+ use arrow_array::Float16Array;
+ let testdata = arrow::util::test_util::parquet_test_data();
+ // see https://github.com/apache/parquet-testing/pull/40
+ let path = format!("{testdata}/float16_zeros_and_nans.parquet");
+ let file = File::open(path).unwrap();
+ let mut record_reader = ParquetRecordBatchReader::try_new(file,
32).unwrap();
+
+ let batch = record_reader.next().unwrap().unwrap();
+ assert_eq!(batch.num_rows(), 3);
+ let col = batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<Float16Array>()
+ .unwrap();
+
+ assert_eq!(col.null_count(), 1);
+ assert!(col.is_null(0));
+ assert_eq!(col.value(1), f16::ZERO);
+ assert!(col.value(1).is_sign_positive());
+ assert!(col.value(2).is_nan());
+ }
+
/// Parameters for single_column_reader_test
#[derive(Clone)]
struct TestOptions {
diff --git a/parquet/src/arrow/arrow_writer/mod.rs
b/parquet/src/arrow/arrow_writer/mod.rs
index eca1dea791..ea7b1eee99 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -771,6 +771,10 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels:
&ArrayLevels) -> Result<usi
.unwrap();
get_decimal_256_array_slice(array, indices)
}
+ ArrowDataType::Float16 => {
+ let array = column.as_primitive::<Float16Type>();
+ get_float_16_array_slice(array, indices)
+ }
_ => {
return Err(ParquetError::NYI(
"Attempting to write an Arrow type that is not yet implemented".to_string(),
@@ -867,6 +871,18 @@ fn get_decimal_256_array_slice(
values
}
+fn get_float_16_array_slice(
+ array: &arrow_array::Float16Array,
+ indices: &[usize],
+) -> Vec<FixedLenByteArray> {
+ let mut values = Vec::with_capacity(indices.len());
+ for i in indices {
+ let value = array.value(*i).to_le_bytes().to_vec();
+ values.push(FixedLenByteArray::from(ByteArray::from(value)));
+ }
+ values
+}
+
fn get_fsb_array_slice(
array: &arrow_array::FixedSizeBinaryArray,
indices: &[usize],
diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs
index d56cc42d43..4c350c4b1d 100644
--- a/parquet/src/arrow/schema/mod.rs
+++ b/parquet/src/arrow/schema/mod.rs
@@ -373,7 +373,12 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
.with_repetition(repetition)
.with_id(id)
.build(),
- DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")),
+ DataType::Float16 => Type::primitive_type_builder(name,
PhysicalType::FIXED_LEN_BYTE_ARRAY)
+ .with_repetition(repetition)
+ .with_id(id)
+ .with_logical_type(Some(LogicalType::Float16))
+ .with_length(2)
+ .build(),
DataType::Float32 => Type::primitive_type_builder(name,
PhysicalType::FLOAT)
.with_repetition(repetition)
.with_id(id)
@@ -604,9 +609,10 @@ mod tests {
REQUIRED INT32 uint8 (INTEGER(8,false));
REQUIRED INT32 uint16 (INTEGER(16,false));
REQUIRED INT32 int32;
- REQUIRED INT64 int64 ;
+ REQUIRED INT64 int64;
OPTIONAL DOUBLE double;
OPTIONAL FLOAT float;
+ OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
OPTIONAL BINARY string (UTF8);
OPTIONAL BINARY string_2 (STRING);
OPTIONAL BINARY json (JSON);
@@ -628,6 +634,7 @@ mod tests {
Field::new("int64", DataType::Int64, false),
Field::new("double", DataType::Float64, true),
Field::new("float", DataType::Float32, true),
+ Field::new("float16", DataType::Float16, true),
Field::new("string", DataType::Utf8, true),
Field::new("string_2", DataType::Utf8, true),
Field::new("json", DataType::Utf8, true),
@@ -1303,6 +1310,7 @@ mod tests {
REQUIRED INT64 int64;
OPTIONAL DOUBLE double;
OPTIONAL FLOAT float;
+ OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
OPTIONAL BINARY string (UTF8);
REPEATED BOOLEAN bools;
OPTIONAL INT32 date (DATE);
@@ -1339,6 +1347,7 @@ mod tests {
Field::new("int64", DataType::Int64, false),
Field::new("double", DataType::Float64, true),
Field::new("float", DataType::Float32, true),
+ Field::new("float16", DataType::Float16, true),
Field::new("string", DataType::Utf8, true),
Field::new_list(
"bools",
@@ -1398,6 +1407,7 @@ mod tests {
REQUIRED INT64 int64;
OPTIONAL DOUBLE double;
OPTIONAL FLOAT float;
+ OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
OPTIONAL BINARY string (STRING);
OPTIONAL GROUP bools (LIST) {
REPEATED GROUP list {
@@ -1448,6 +1458,7 @@ mod tests {
Field::new("int64", DataType::Int64, false),
Field::new("double", DataType::Float64, true),
Field::new("float", DataType::Float32, true),
+ Field::new("float16", DataType::Float16, true),
Field::new("string", DataType::Utf8, true),
Field::new_list(
"bools",
@@ -1661,6 +1672,8 @@ mod tests {
vec![
Field::new("a", DataType::Int16, true),
Field::new("b", DataType::Float64, false),
+ Field::new("c", DataType::Float32, false),
+ Field::new("d", DataType::Float16, false),
]
.into(),
),
diff --git a/parquet/src/arrow/schema/primitive.rs
b/parquet/src/arrow/schema/primitive.rs
index 7d8b6a04ee..fdc744831a 100644
--- a/parquet/src/arrow/schema/primitive.rs
+++ b/parquet/src/arrow/schema/primitive.rs
@@ -304,6 +304,16 @@ fn from_fixed_len_byte_array(
// would be incorrect if all 12 bytes of the interval are populated
Ok(DataType::Interval(IntervalUnit::DayTime))
}
+ (Some(LogicalType::Float16), _) => {
+ if type_length == 2 {
+ Ok(DataType::Float16)
+ } else {
+ Err(ParquetError::General(
+ "FLOAT16 logical type must be Fixed Length Byte Array with length 2"
+ .to_string(),
+ ))
+ }
+ }
_ => Ok(DataType::FixedSizeBinary(type_length)),
}
}
diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs
index 3c8602b802..2327e1d84b 100644
--- a/parquet/src/basic.rs
+++ b/parquet/src/basic.rs
@@ -194,6 +194,7 @@ pub enum LogicalType {
Json,
Bson,
Uuid,
+ Float16,
}
// ----------------------------------------------------------------------
@@ -505,6 +506,7 @@ impl ColumnOrder {
LogicalType::Timestamp { .. } => SortOrder::SIGNED,
LogicalType::Unknown => SortOrder::UNDEFINED,
LogicalType::Uuid => SortOrder::UNSIGNED,
+ LogicalType::Float16 => SortOrder::SIGNED,
},
// Fall back to converted type
None => Self::get_converted_sort_order(converted_type,
physical_type),
@@ -766,6 +768,7 @@ impl From<parquet::LogicalType> for LogicalType {
parquet::LogicalType::JSON(_) => LogicalType::Json,
parquet::LogicalType::BSON(_) => LogicalType::Bson,
parquet::LogicalType::UUID(_) => LogicalType::Uuid,
+ parquet::LogicalType::FLOAT16(_) => LogicalType::Float16,
}
}
}
@@ -806,6 +809,7 @@ impl From<LogicalType> for parquet::LogicalType {
LogicalType::Json =>
parquet::LogicalType::JSON(Default::default()),
LogicalType::Bson =>
parquet::LogicalType::BSON(Default::default()),
LogicalType::Uuid =>
parquet::LogicalType::UUID(Default::default()),
+ LogicalType::Float16 =>
parquet::LogicalType::FLOAT16(Default::default()),
}
}
}
@@ -853,10 +857,11 @@ impl From<Option<LogicalType>> for ConvertedType {
(64, false) => ConvertedType::UINT_64,
t => panic!("Integer type {t:?} is not supported"),
},
- LogicalType::Unknown => ConvertedType::NONE,
LogicalType::Json => ConvertedType::JSON,
LogicalType::Bson => ConvertedType::BSON,
- LogicalType::Uuid => ConvertedType::NONE,
+ LogicalType::Uuid | LogicalType::Float16 |
LogicalType::Unknown => {
+ ConvertedType::NONE
+ }
},
None => ConvertedType::NONE,
}
@@ -1102,6 +1107,7 @@ impl str::FromStr for LogicalType {
"INTERVAL" => Err(general_err!(
"Interval parquet logical type not yet supported"
)),
+ "FLOAT16" => Ok(LogicalType::Float16),
other => Err(general_err!("Invalid parquet logical type {}",
other)),
}
}
@@ -1746,6 +1752,10 @@ mod tests {
ConvertedType::from(Some(LogicalType::Enum)),
ConvertedType::ENUM
);
+ assert_eq!(
+ ConvertedType::from(Some(LogicalType::Float16)),
+ ConvertedType::NONE
+ );
assert_eq!(
ConvertedType::from(Some(LogicalType::Unknown)),
ConvertedType::NONE
@@ -2119,6 +2129,7 @@ mod tests {
is_adjusted_to_u_t_c: true,
unit: TimeUnit::NANOS(Default::default()),
},
+ LogicalType::Float16,
];
check_sort_order(signed, SortOrder::SIGNED);
diff --git a/parquet/src/column/writer/encoder.rs
b/parquet/src/column/writer/encoder.rs
index 2273ae7774..d0720dd243 100644
--- a/parquet/src/column/writer/encoder.rs
+++ b/parquet/src/column/writer/encoder.rs
@@ -16,8 +16,9 @@
// under the License.
use bytes::Bytes;
+use half::f16;
-use crate::basic::{Encoding, Type};
+use crate::basic::{Encoding, LogicalType, Type};
use crate::bloom_filter::Sbbf;
use crate::column::writer::{
compare_greater, fallback_encoding, has_dictionary_support, is_nan,
update_max, update_min,
@@ -291,7 +292,7 @@ where
{
let first = loop {
let next = iter.next()?;
- if !is_nan(next) {
+ if !is_nan(descr, next) {
break next;
}
};
@@ -299,7 +300,7 @@ where
let mut min = first;
let mut max = first;
for val in iter {
- if is_nan(val) {
+ if is_nan(descr, val) {
continue;
}
if compare_greater(descr, min, val) {
@@ -318,14 +319,14 @@ where
//
// For max, it has similar logic but will be written as 0.0
// (positive zero)
- let min = replace_zero(min, -0.0);
- let max = replace_zero(max, 0.0);
+ let min = replace_zero(min, descr, -0.0);
+ let max = replace_zero(max, descr, 0.0);
Some((min, max))
}
#[inline]
-fn replace_zero<T: ParquetValueType>(val: &T, replace: f32) -> T {
+fn replace_zero<T: ParquetValueType>(val: &T, descr: &ColumnDescriptor,
replace: f32) -> T {
match T::PHYSICAL_TYPE {
Type::FLOAT if f32::from_le_bytes(val.as_bytes().try_into().unwrap())
== 0.0 => {
T::try_from_le_slice(&f32::to_le_bytes(replace)).unwrap()
@@ -333,6 +334,12 @@ fn replace_zero<T: ParquetValueType>(val: &T, replace:
f32) -> T {
Type::DOUBLE if f64::from_le_bytes(val.as_bytes().try_into().unwrap())
== 0.0 => {
T::try_from_le_slice(&f64::to_le_bytes(replace as f64)).unwrap()
}
+ Type::FIXED_LEN_BYTE_ARRAY
+ if descr.logical_type() == Some(LogicalType::Float16)
+ && f16::from_le_bytes(val.as_bytes().try_into().unwrap()) == f16::NEG_ZERO =>
+ {
+ T::try_from_le_slice(&f16::to_le_bytes(f16::from_f32(replace))).unwrap()
+ }
_ => val.clone(),
}
}
diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs
index 60db90c5d4..a917c48649 100644
--- a/parquet/src/column/writer/mod.rs
+++ b/parquet/src/column/writer/mod.rs
@@ -18,6 +18,7 @@
//! Contains column writer API.
use bytes::Bytes;
+use half::f16;
use crate::bloom_filter::Sbbf;
use crate::format::{ColumnIndex, OffsetIndex};
@@ -968,18 +969,23 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a,
E> {
}
fn update_min<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, min:
&mut Option<T>) {
- update_stat::<T, _>(val, min, |cur| compare_greater(descr, cur, val))
+ update_stat::<T, _>(descr, val, min, |cur| compare_greater(descr, cur,
val))
}
fn update_max<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, max:
&mut Option<T>) {
- update_stat::<T, _>(val, max, |cur| compare_greater(descr, val, cur))
+ update_stat::<T, _>(descr, val, max, |cur| compare_greater(descr, val,
cur))
}
#[inline]
#[allow(clippy::eq_op)]
-fn is_nan<T: ParquetValueType>(val: &T) -> bool {
+fn is_nan<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T) -> bool {
match T::PHYSICAL_TYPE {
Type::FLOAT | Type::DOUBLE => val != val,
+ Type::FIXED_LEN_BYTE_ARRAY if descr.logical_type() ==
Some(LogicalType::Float16) => {
+ let val = val.as_bytes();
+ let val = f16::from_le_bytes([val[0], val[1]]);
+ val.is_nan()
+ }
_ => false,
}
}
@@ -989,11 +995,15 @@ fn is_nan<T: ParquetValueType>(val: &T) -> bool {
/// If `cur` is `None`, sets `cur` to `Some(val)`, otherwise calls
`should_update` with
/// the value of `cur`, and updates `cur` to `Some(val)` if it returns `true`
-fn update_stat<T: ParquetValueType, F>(val: &T, cur: &mut Option<T>,
should_update: F)
-where
+fn update_stat<T: ParquetValueType, F>(
+ descr: &ColumnDescriptor,
+ val: &T,
+ cur: &mut Option<T>,
+ should_update: F,
+) where
F: Fn(&T) -> bool,
{
- if is_nan(val) {
+ if is_nan(descr, val) {
return;
}
@@ -1039,6 +1049,14 @@ fn compare_greater<T: ParquetValueType>(descr:
&ColumnDescriptor, a: &T, b: &T)
};
};
+ if let Some(LogicalType::Float16) = descr.logical_type() {
+ let a = a.as_bytes();
+ let a = f16::from_le_bytes([a[0], a[1]]);
+ let b = b.as_bytes();
+ let b = f16::from_le_bytes([b[0], b[1]]);
+ return a > b;
+ }
+
a > b
}
@@ -1170,6 +1188,7 @@ fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
mod tests {
use crate::{file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH,
format::BoundaryOrder};
use bytes::Bytes;
+ use half::f16;
use rand::distributions::uniform::SampleUniform;
use std::sync::Arc;
@@ -2078,6 +2097,135 @@ mod tests {
}
}
+ #[test]
+ fn test_column_writer_check_float16_min_max() {
+ let input = [
+ -f16::ONE,
+ f16::from_f32(3.0),
+ -f16::from_f32(2.0),
+ f16::from_f32(2.0),
+ ]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(-f16::from_f32(2.0)));
+ assert_eq!(stats.max(), &ByteArray::from(f16::from_f32(3.0)));
+ }
+
+ #[test]
+ fn test_column_writer_check_float16_nan_middle() {
+ let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+ assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+ }
+
+ #[test]
+ fn test_float16_statistics_nan_middle() {
+ let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+ assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+ }
+
+ #[test]
+ fn test_float16_statistics_nan_start() {
+ let input = [f16::NAN, f16::ONE, f16::ONE + f16::ONE]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+ assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+ }
+
+ #[test]
+ fn test_float16_statistics_nan_only() {
+ let input = [f16::NAN, f16::NAN]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(!stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ }
+
+ #[test]
+ fn test_float16_statistics_zero_only() {
+ let input = [f16::ZERO]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+ assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+ }
+
+ #[test]
+ fn test_float16_statistics_neg_zero_only() {
+ let input = [f16::NEG_ZERO]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+ assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+ }
+
+ #[test]
+ fn test_float16_statistics_zero_min() {
+ let input = [f16::ZERO, f16::ONE, f16::NAN, f16::PI]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+ assert_eq!(stats.max(), &ByteArray::from(f16::PI));
+ }
+
+ #[test]
+ fn test_float16_statistics_neg_zero_max() {
+ let input = [f16::NEG_ZERO, f16::NEG_ONE, f16::NAN, -f16::PI]
+ .into_iter()
+ .map(|s| ByteArray::from(s).into())
+ .collect::<Vec<_>>();
+
+ let stats = float16_statistics_roundtrip(&input);
+ assert!(stats.has_min_max_set());
+ assert!(stats.is_min_max_backwards_compatible());
+ assert_eq!(stats.min(), &ByteArray::from(-f16::PI));
+ assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+ }
+
#[test]
fn test_float_statistics_nan_middle() {
let stats = statistics_roundtrip::<FloatType>(&[1.0, f32::NAN, 2.0]);
@@ -2850,6 +2998,50 @@ mod tests {
ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level,
path)
}
+ fn float16_statistics_roundtrip(
+ values: &[FixedLenByteArray],
+ ) -> ValueStatistics<FixedLenByteArray> {
+ let page_writer = get_test_page_writer();
+ let props = Default::default();
+ let mut writer =
+ get_test_float16_column_writer::<FixedLenByteArrayType>(page_writer, 0, 0, props);
+ writer.write_batch(values, None, None).unwrap();
+
+ let metadata = writer.close().unwrap().metadata;
+ if let Some(Statistics::FixedLenByteArray(stats)) =
metadata.statistics() {
+ stats.clone()
+ } else {
+ panic!("metadata missing statistics");
+ }
+ }
+
+ fn get_test_float16_column_writer<T: DataType>(
+ page_writer: Box<dyn PageWriter>,
+ max_def_level: i16,
+ max_rep_level: i16,
+ props: WriterPropertiesPtr,
+ ) -> ColumnWriterImpl<'static, T> {
+ let descr = Arc::new(get_test_float16_column_descr::<T>(
+ max_def_level,
+ max_rep_level,
+ ));
+ let column_writer = get_column_writer(descr, props, page_writer);
+ get_typed_column_writer::<T>(column_writer)
+ }
+
+ fn get_test_float16_column_descr<T: DataType>(
+ max_def_level: i16,
+ max_rep_level: i16,
+ ) -> ColumnDescriptor {
+ let path = ColumnPath::from("col");
+ let tpe = SchemaType::primitive_type_builder("col",
T::get_physical_type())
+ .with_length(2)
+ .with_logical_type(Some(LogicalType::Float16))
+ .build()
+ .unwrap();
+ ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level,
path)
+ }
+
/// Returns column writer for UINT32 Column provided as ConvertedType only
fn get_test_unsigned_int_given_as_converted_column_writer<'a, T: DataType>(
page_writer: Box<dyn PageWriter + 'a>,
diff --git a/parquet/src/data_type.rs b/parquet/src/data_type.rs
index b895c25070..86da7a3ace 100644
--- a/parquet/src/data_type.rs
+++ b/parquet/src/data_type.rs
@@ -18,6 +18,7 @@
//! Data types that connect Parquet physical types with their Rust-specific
//! representations.
use bytes::Bytes;
+use half::f16;
use std::cmp::Ordering;
use std::fmt;
use std::mem;
@@ -225,6 +226,12 @@ impl From<Bytes> for ByteArray {
}
}
+impl From<f16> for ByteArray {
+ fn from(value: f16) -> Self {
+ Self::from(value.to_le_bytes().as_slice())
+ }
+}
+
impl PartialEq for ByteArray {
fn eq(&self, other: &ByteArray) -> bool {
match (&self.data, &other.data) {
diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs
index b36e37a80c..345fe7dd26 100644
--- a/parquet/src/file/statistics.rs
+++ b/parquet/src/file/statistics.rs
@@ -243,6 +243,8 @@ pub fn to_thrift(stats: Option<&Statistics>) ->
Option<TStatistics> {
distinct_count: stats.distinct_count().map(|value| value as i64),
max_value: None,
min_value: None,
+ is_max_value_exact: None,
+ is_min_value_exact: None,
};
// Get min/max if set.
@@ -607,6 +609,8 @@ mod tests {
distinct_count: None,
max_value: None,
min_value: None,
+ is_max_value_exact: None,
+ is_min_value_exact: None,
};
from_thrift(Type::INT32, Some(thrift_stats)).unwrap();
diff --git a/parquet/src/format.rs b/parquet/src/format.rs
index 46adc39e64..4700b05dc2 100644
--- a/parquet/src/format.rs
+++ b/parquet/src/format.rs
@@ -657,16 +657,26 @@ pub struct Statistics {
pub null_count: Option<i64>,
/// count of distinct values occurring
pub distinct_count: Option<i64>,
- /// Min and max values for the column, determined by its ColumnOrder.
+ /// Lower and upper bound values for the column, determined by its ColumnOrder.
+ ///
+ /// These may be the actual minimum and maximum values found on a page or column
+ /// chunk, but can also be (more compact) values that do not exist on a page or
+ /// column chunk. For example, instead of storing "Blart Versenwald III", a writer
+ /// may set min_value="B", max_value="C". Such more compact values must still be
+ /// valid values within the column's logical type.
///
/// Values are encoded using PLAIN encoding, except that variable-length byte
/// arrays do not include a length prefix.
pub max_value: Option<Vec<u8>>,
pub min_value: Option<Vec<u8>>,
+ /// If true, max_value is the actual maximum value for a column
+ pub is_max_value_exact: Option<bool>,
+ /// If true, min_value is the actual minimum value for a column
+ pub is_min_value_exact: Option<bool>,
}
impl Statistics {
- pub fn new<F1, F2, F3, F4, F5, F6>(max: F1, min: F2, null_count: F3,
distinct_count: F4, max_value: F5, min_value: F6) -> Statistics where F1:
Into<Option<Vec<u8>>>, F2: Into<Option<Vec<u8>>>, F3: Into<Option<i64>>, F4:
Into<Option<i64>>, F5: Into<Option<Vec<u8>>>, F6: Into<Option<Vec<u8>>> {
+ pub fn new<F1, F2, F3, F4, F5, F6, F7, F8>(max: F1, min: F2, null_count: F3,
distinct_count: F4, max_value: F5, min_value: F6, is_max_value_exact: F7,
is_min_value_exact: F8) -> Statistics where F1: Into<Option<Vec<u8>>>, F2:
Into<Option<Vec<u8>>>, F3: Into<Option<i64>>, F4: Into<Option<i64>>, F5:
Into<Option<Vec<u8>>>, F6: Into<Option<Vec<u8>>>, F7: Into<Option<bool>>, F8:
Into<Option<bool>> {
Statistics {
max: max.into(),
min: min.into(),
@@ -674,6 +684,8 @@ impl Statistics {
distinct_count: distinct_count.into(),
max_value: max_value.into(),
min_value: min_value.into(),
+ is_max_value_exact: is_max_value_exact.into(),
+ is_min_value_exact: is_min_value_exact.into(),
}
}
}
@@ -687,6 +699,8 @@ impl crate::thrift::TSerializable for Statistics {
let mut f_4: Option<i64> = None;
let mut f_5: Option<Vec<u8>> = None;
let mut f_6: Option<Vec<u8>> = None;
+ let mut f_7: Option<bool> = None;
+ let mut f_8: Option<bool> = None;
loop {
let field_ident = i_prot.read_field_begin()?;
if field_ident.field_type == TType::Stop {
@@ -718,6 +732,14 @@ impl crate::thrift::TSerializable for Statistics {
let val = i_prot.read_bytes()?;
f_6 = Some(val);
},
+ 7 => {
+ let val = i_prot.read_bool()?;
+ f_7 = Some(val);
+ },
+ 8 => {
+ let val = i_prot.read_bool()?;
+ f_8 = Some(val);
+ },
_ => {
i_prot.skip(field_ident.field_type)?;
},
@@ -732,6 +754,8 @@ impl crate::thrift::TSerializable for Statistics {
distinct_count: f_4,
max_value: f_5,
min_value: f_6,
+ is_max_value_exact: f_7,
+ is_min_value_exact: f_8,
};
Ok(ret)
}
@@ -768,6 +792,16 @@ impl crate::thrift::TSerializable for Statistics {
o_prot.write_bytes(fld_var)?;
o_prot.write_field_end()?
}
+ if let Some(fld_var) = self.is_max_value_exact {
+ o_prot.write_field_begin(&TFieldIdentifier::new("is_max_value_exact",
TType::Bool, 7))?;
+ o_prot.write_bool(fld_var)?;
+ o_prot.write_field_end()?
+ }
+ if let Some(fld_var) = self.is_min_value_exact {
+ o_prot.write_field_begin(&TFieldIdentifier::new("is_min_value_exact",
TType::Bool, 8))?;
+ o_prot.write_bool(fld_var)?;
+ o_prot.write_field_end()?
+ }
o_prot.write_field_stop()?;
o_prot.write_struct_end()
}
@@ -996,6 +1030,43 @@ impl crate::thrift::TSerializable for DateType {
}
}
+//
+// Float16Type
+//
+
+#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
+pub struct Float16Type {
+}
+
+impl Float16Type {
+ pub fn new() -> Float16Type {
+ Float16Type {}
+ }
+}
+
+impl crate::thrift::TSerializable for Float16Type {
+ fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) ->
thrift::Result<Float16Type> {
+ i_prot.read_struct_begin()?;
+ loop {
+ let field_ident = i_prot.read_field_begin()?;
+ if field_ident.field_type == TType::Stop {
+ break;
+ }
+ i_prot.skip(field_ident.field_type)?;
+ i_prot.read_field_end()?;
+ }
+ i_prot.read_struct_end()?;
+ let ret = Float16Type {};
+ Ok(ret)
+ }
+ fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) ->
thrift::Result<()> {
+ let struct_ident = TStructIdentifier::new("Float16Type");
+ o_prot.write_struct_begin(&struct_ident)?;
+ o_prot.write_field_stop()?;
+ o_prot.write_struct_end()
+ }
+}
+
//
// NullType
//
@@ -1640,6 +1711,7 @@ pub enum LogicalType {
JSON(JsonType),
BSON(BsonType),
UUID(UUIDType),
+ FLOAT16(Float16Type),
}
impl crate::thrift::TSerializable for LogicalType {
@@ -1745,6 +1817,13 @@ impl crate::thrift::TSerializable for LogicalType {
}
received_field_count += 1;
},
+ 15 => {
+ let val = Float16Type::read_from_in_protocol(i_prot)?;
+ if ret.is_none() {
+ ret = Some(LogicalType::FLOAT16(val));
+ }
+ received_field_count += 1;
+ },
_ => {
i_prot.skip(field_ident.field_type)?;
received_field_count += 1;
@@ -1844,6 +1923,11 @@ impl crate::thrift::TSerializable for LogicalType {
f.write_to_out_protocol(o_prot)?;
o_prot.write_field_end()?;
},
+ LogicalType::FLOAT16(ref f) => {
+ o_prot.write_field_begin(&TFieldIdentifier::new("FLOAT16", TType::Struct, 15))?;
+ f.write_to_out_protocol(o_prot)?;
+ o_prot.write_field_end()?;
+ },
}
o_prot.write_field_stop()?;
o_prot.write_struct_end()
diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs
index c7a0b09c37..e4f473562e 100644
--- a/parquet/src/record/api.rs
+++ b/parquet/src/record/api.rs
@@ -20,9 +20,11 @@
use std::fmt;
use chrono::{TimeZone, Utc};
+use half::f16;
+use num::traits::Float;
use num_bigint::{BigInt, Sign};
-use crate::basic::{ConvertedType, Type as PhysicalType};
+use crate::basic::{ConvertedType, LogicalType, Type as PhysicalType};
use crate::data_type::{ByteArray, Decimal, Int96};
use crate::errors::{ParquetError, Result};
use crate::schema::types::ColumnDescPtr;
@@ -121,6 +123,7 @@ pub trait RowAccessor {
fn get_ushort(&self, i: usize) -> Result<u16>;
fn get_uint(&self, i: usize) -> Result<u32>;
fn get_ulong(&self, i: usize) -> Result<u64>;
+ fn get_float16(&self, i: usize) -> Result<f16>;
fn get_float(&self, i: usize) -> Result<f32>;
fn get_double(&self, i: usize) -> Result<f64>;
fn get_timestamp_millis(&self, i: usize) -> Result<i64>;
@@ -215,6 +218,8 @@ impl RowAccessor for Row {
row_primitive_accessor!(get_ulong, ULong, u64);
+ row_primitive_accessor!(get_float16, Float16, f16);
+
row_primitive_accessor!(get_float, Float, f32);
row_primitive_accessor!(get_double, Double, f64);
@@ -293,6 +298,7 @@ pub trait ListAccessor {
fn get_ushort(&self, i: usize) -> Result<u16>;
fn get_uint(&self, i: usize) -> Result<u32>;
fn get_ulong(&self, i: usize) -> Result<u64>;
+ fn get_float16(&self, i: usize) -> Result<f16>;
fn get_float(&self, i: usize) -> Result<f32>;
fn get_double(&self, i: usize) -> Result<f64>;
fn get_timestamp_millis(&self, i: usize) -> Result<i64>;
@@ -358,6 +364,8 @@ impl ListAccessor for List {
list_primitive_accessor!(get_ulong, ULong, u64);
+ list_primitive_accessor!(get_float16, Float16, f16);
+
list_primitive_accessor!(get_float, Float, f32);
list_primitive_accessor!(get_double, Double, f64);
@@ -449,6 +457,8 @@ impl<'a> ListAccessor for MapList<'a> {
map_list_primitive_accessor!(get_ulong, ULong, u64);
+ map_list_primitive_accessor!(get_float16, Float16, f16);
+
map_list_primitive_accessor!(get_float, Float, f32);
map_list_primitive_accessor!(get_double, Double, f64);
@@ -510,6 +520,8 @@ pub enum Field {
UInt(u32),
// Unsigned integer UINT_64.
ULong(u64),
+ /// IEEE 16-bit floating point value.
+ Float16(f16),
/// IEEE 32-bit floating point value.
Float(f32),
/// IEEE 64-bit floating point value.
@@ -552,6 +564,7 @@ impl Field {
Field::UShort(_) => "UShort",
Field::UInt(_) => "UInt",
Field::ULong(_) => "ULong",
+ Field::Float16(_) => "Float16",
Field::Float(_) => "Float",
Field::Double(_) => "Double",
Field::Decimal(_) => "Decimal",
@@ -636,8 +649,8 @@ impl Field {
Field::Double(value)
}
- /// Converts Parquet BYTE_ARRAY type with converted type into either UTF8 string or
- /// array of bytes.
+ /// Converts Parquet BYTE_ARRAY type with converted type into a UTF8
+ /// string, decimal, float16, or an array of bytes.
#[inline]
pub fn convert_byte_array(descr: &ColumnDescPtr, value: ByteArray) -> Result<Self> {
let field = match descr.physical_type() {
@@ -666,6 +679,16 @@ impl Field {
descr.type_precision(),
descr.type_scale(),
)),
+ ConvertedType::NONE if descr.logical_type() == Some(LogicalType::Float16) => {
+ if value.len() != 2 {
+ return Err(general_err!(
+ "Error reading FIXED_LEN_BYTE_ARRAY as FLOAT16. Length must be 2, got {}",
+ value.len()
+ ));
+ }
+ let bytes = [value.data()[0], value.data()[1]];
+ Field::Float16(f16::from_le_bytes(bytes))
+ }
ConvertedType::NONE => Field::Bytes(value),
_ => nyi!(descr, value),
},
@@ -690,6 +713,9 @@ impl Field {
Field::UShort(n) => Value::Number(serde_json::Number::from(*n)),
Field::UInt(n) => Value::Number(serde_json::Number::from(*n)),
Field::ULong(n) => Value::Number(serde_json::Number::from(*n)),
+ Field::Float16(n) => serde_json::Number::from_f64(f64::from(*n))
+ .map(Value::Number)
+ .unwrap_or(Value::Null),
Field::Float(n) => serde_json::Number::from_f64(f64::from(*n))
.map(Value::Number)
.unwrap_or(Value::Null),
@@ -736,6 +762,15 @@ impl fmt::Display for Field {
Field::UShort(value) => write!(f, "{value}"),
Field::UInt(value) => write!(f, "{value}"),
Field::ULong(value) => write!(f, "{value}"),
+ Field::Float16(value) => {
+ if !value.is_finite() {
+ write!(f, "{value}")
+ } else if value.trunc() == value {
+ write!(f, "{value}.0")
+ } else {
+ write!(f, "{value}")
+ }
+ }
Field::Float(value) => {
if !(1e-15..=1e19).contains(&value) {
write!(f, "{value:E}")
@@ -1069,6 +1104,24 @@ mod tests {
Field::Decimal(Decimal::from_bytes(value, 17, 5))
);
+ // FLOAT16
+ let descr = {
+ let tpe = PrimitiveTypeBuilder::new("col", PhysicalType::FIXED_LEN_BYTE_ARRAY)
+ .with_logical_type(Some(LogicalType::Float16))
+ .with_length(2)
+ .build()
+ .unwrap();
+ Arc::new(ColumnDescriptor::new(
+ Arc::new(tpe),
+ 0,
+ 0,
+ ColumnPath::from("col"),
+ ))
+ };
+ let value = ByteArray::from(f16::PI);
+ let row = Field::convert_byte_array(&descr, value.clone());
+ assert_eq!(row.unwrap(), Field::Float16(f16::PI));
+
// NONE (FIXED_LEN_BYTE_ARRAY)
let descr = make_column_descr![
PhysicalType::FIXED_LEN_BYTE_ARRAY,
@@ -1145,6 +1198,18 @@ mod tests {
check_datetime_conversion(2014, 11, 28, 21, 15, 12);
}
+ #[test]
+ fn test_convert_float16_to_string() {
+ assert_eq!(format!("{}", Field::Float16(f16::ONE)), "1.0");
+ assert_eq!(format!("{}", Field::Float16(f16::PI)), "3.140625");
+ assert_eq!(format!("{}", Field::Float16(f16::MAX)), "65504.0");
+ assert_eq!(format!("{}", Field::Float16(f16::NAN)), "NaN");
+ assert_eq!(format!("{}", Field::Float16(f16::INFINITY)), "inf");
+ assert_eq!(format!("{}", Field::Float16(f16::NEG_INFINITY)), "-inf");
+ assert_eq!(format!("{}", Field::Float16(f16::ZERO)), "0.0");
+ assert_eq!(format!("{}", Field::Float16(f16::NEG_ZERO)), "-0.0");
+ }
+
#[test]
fn test_convert_float_to_string() {
assert_eq!(format!("{}", Field::Float(1.0)), "1.0");
@@ -1218,6 +1283,7 @@ mod tests {
assert_eq!(format!("{}", Field::UShort(2)), "2");
assert_eq!(format!("{}", Field::UInt(3)), "3");
assert_eq!(format!("{}", Field::ULong(4)), "4");
+ assert_eq!(format!("{}", Field::Float16(f16::E)), "2.71875");
assert_eq!(format!("{}", Field::Float(5.0)), "5.0");
assert_eq!(format!("{}", Field::Float(5.1234)), "5.1234");
assert_eq!(format!("{}", Field::Double(6.0)), "6.0");
@@ -1284,6 +1350,7 @@ mod tests {
assert!(Field::UShort(2).is_primitive());
assert!(Field::UInt(3).is_primitive());
assert!(Field::ULong(4).is_primitive());
+ assert!(Field::Float16(f16::E).is_primitive());
assert!(Field::Float(5.0).is_primitive());
assert!(Field::Float(5.1234).is_primitive());
assert!(Field::Double(6.0).is_primitive());
@@ -1344,6 +1411,7 @@ mod tests {
("15".to_string(), Field::TimestampMillis(1262391174000)),
("16".to_string(), Field::TimestampMicros(1262391174000000)),
("17".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))),
+ ("18".to_string(), Field::Float16(f16::PI)),
]);
assert_eq!("null", format!("{}", row.fmt(0)));
@@ -1370,6 +1438,7 @@ mod tests {
format!("{}", row.fmt(16))
);
assert_eq!("0.04", format!("{}", row.fmt(17)));
+ assert_eq!("3.140625", format!("{}", row.fmt(18)));
}
#[test]
@@ -1429,6 +1498,7 @@ mod tests {
Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])),
),
("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))),
+ ("p".to_string(), Field::Float16(f16::from_f32(9.1))),
]);
assert!(!row.get_bool(1).unwrap());
@@ -1445,6 +1515,7 @@ mod tests {
assert_eq!("abc", row.get_string(12).unwrap());
assert_eq!(5, row.get_bytes(13).unwrap().len());
assert_eq!(7, row.get_decimal(14).unwrap().precision());
+ assert!((f16::from_f32(9.1) - row.get_float16(15).unwrap()).abs() < f16::EPSILON);
}
#[test]
@@ -1469,6 +1540,7 @@ mod tests {
Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])),
),
("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))),
+ ("p".to_string(), Field::Float16(f16::from_f32(9.1))),
]);
for i in 0..row.len() {
@@ -1583,6 +1655,9 @@ mod tests {
let list = make_list(vec![Field::ULong(6), Field::ULong(7)]);
assert_eq!(7, list.get_ulong(1).unwrap());
+ let list = make_list(vec![Field::Float16(f16::PI)]);
+ assert!((f16::PI - list.get_float16(0).unwrap()).abs() < f16::EPSILON);
+
let list = make_list(vec![
Field::Float(8.1),
Field::Float(9.2),
@@ -1633,6 +1708,9 @@ mod tests {
let list = make_list(vec![Field::ULong(6), Field::ULong(7)]);
assert!(list.get_float(1).is_err());
+ let list = make_list(vec![Field::Float16(f16::PI)]);
+ assert!(list.get_string(0).is_err());
+
let list = make_list(vec![
Field::Float(8.1),
Field::Float(9.2),
@@ -1768,6 +1846,10 @@ mod tests {
Field::ULong(4).to_json_value(),
Value::Number(serde_json::Number::from(4))
);
+ assert_eq!(
+ Field::Float16(f16::from_f32(5.0)).to_json_value(),
+ Value::Number(serde_json::Number::from_f64(5.0).unwrap())
+ );
assert_eq!(
Field::Float(5.0).to_json_value(),
Value::Number(serde_json::Number::from_f64(5.0).unwrap())
diff --git a/parquet/src/schema/parser.rs b/parquet/src/schema/parser.rs
index 5e213e3bb9..dcef11aa66 100644
--- a/parquet/src/schema/parser.rs
+++ b/parquet/src/schema/parser.rs
@@ -823,6 +823,7 @@ mod tests {
message root {
optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3));
optional fixed_len_byte_array (16) f2 (DECIMAL (38, 18));
+ optional fixed_len_byte_array (2) f3 (FLOAT16);
}
";
let message = parse(schema).unwrap();
@@ -855,6 +856,13 @@ mod tests {
.build()
.unwrap(),
),
+ Arc::new(
+ Type::primitive_type_builder("f3", PhysicalType::FIXED_LEN_BYTE_ARRAY)
+ .with_logical_type(Some(LogicalType::Float16))
+ .with_length(2)
+ .build()
+ .unwrap(),
+ ),
])
.build()
.unwrap();
diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs
index fe4757d41a..2dec8a5be9 100644
--- a/parquet/src/schema/printer.rs
+++ b/parquet/src/schema/printer.rs
@@ -270,6 +270,7 @@ fn print_logical_and_converted(
LogicalType::Enum => "ENUM".to_string(),
LogicalType::List => "LIST".to_string(),
LogicalType::Map => "MAP".to_string(),
+ LogicalType::Float16 => "FLOAT16".to_string(),
LogicalType::Unknown => "UNKNOWN".to_string(),
},
None => {
@@ -667,6 +668,15 @@ mod tests {
.unwrap(),
"OPTIONAL FIXED_LEN_BYTE_ARRAY (9) decimal (DECIMAL(19,4));",
),
+ (
+ Type::primitive_type_builder("float16", PhysicalType::FIXED_LEN_BYTE_ARRAY)
+ .with_logical_type(Some(LogicalType::Float16))
+ .with_length(2)
+ .with_repetition(Repetition::REQUIRED)
+ .build()
+ .unwrap(),
+ "REQUIRED FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);",
+ ),
];
types_and_strings.into_iter().for_each(|(field, expected)| {
diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs
index 11c7354209..2f36deffba 100644
--- a/parquet/src/schema/types.rs
+++ b/parquet/src/schema/types.rs
@@ -356,6 +356,14 @@ impl<'a> PrimitiveTypeBuilder<'a> {
(LogicalType::Json, PhysicalType::BYTE_ARRAY) => {}
(LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {}
(LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {}
+ (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY)
+ if self.length == 2 => {}
+ (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {
+ return Err(general_err!(
+ "FLOAT16 cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(2) field",
+ self.name
+ ))
+ }
(a, b) => {
return Err(general_err!(
"Cannot annotate {:?} from {} for field '{}'",
@@ -1504,6 +1512,41 @@ mod tests {
"Parquet error: Invalid FIXED_LEN_BYTE_ARRAY length: -1 for field 'foo'"
);
}
+
+ result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
+ .with_repetition(Repetition::REQUIRED)
+ .with_logical_type(Some(LogicalType::Float16))
+ .with_length(2)
+ .build();
+ assert!(result.is_ok());
+
+ // Can't be other than FIXED_LEN_BYTE_ARRAY for physical type
+ result = Type::primitive_type_builder("foo", PhysicalType::FLOAT)
+ .with_repetition(Repetition::REQUIRED)
+ .with_logical_type(Some(LogicalType::Float16))
+ .with_length(2)
+ .build();
+ assert!(result.is_err());
+ if let Err(e) = result {
+ assert_eq!(
+ format!("{e}"),
+ "Parquet error: Cannot annotate Float16 from FLOAT for field 'foo'"
+ );
+ }
+
+ // Must have length 2
+ result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY)
+ .with_repetition(Repetition::REQUIRED)
+ .with_logical_type(Some(LogicalType::Float16))
+ .with_length(4)
+ .build();
+ assert!(result.is_err());
+ if let Err(e) = result {
+ assert_eq!(
+ format!("{e}"),
+ "Parquet error: FLOAT16 cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(2) field"
+ );
+ }
}
#[test]
@@ -1981,6 +2024,7 @@ mod tests {
let message_type = "
message conversions {
REQUIRED INT64 id;
+ OPTIONAL FIXED_LEN_BYTE_ARRAY (2) f16 (FLOAT16);
OPTIONAL group int_array_Array (LIST) {
REPEATED group list {
OPTIONAL group element (LIST) {