This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 07024f6a1 Treat DecimalArray as PrimitiveArray in row format (#2866)
07024f6a1 is described below

commit 07024f6a16b870fda81cba5779b8817b20386ebf
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed Oct 19 08:10:58 2022 +1300

    Treat DecimalArray as PrimitiveArray in row format (#2866)
---
 arrow-buffer/src/bigint.rs  |  20 +++++++++
 arrow/src/row/dictionary.rs |  21 ++++------
 arrow/src/row/fixed.rs      |  47 +--------------------
 arrow/src/row/mod.rs        | 100 ++++++++++++++++++++++++++++++++++----------
 4 files changed, 108 insertions(+), 80 deletions(-)

diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint.rs
index 7873064b4..3518b85e4 100644
--- a/arrow-buffer/src/bigint.rs
+++ b/arrow-buffer/src/bigint.rs
@@ -86,6 +86,15 @@ impl i256 {
         }
     }
 
+    /// Create an integer value from its representation as a byte array in big-endian.
+    #[inline]
+    pub fn from_be_bytes(b: [u8; 32]) -> Self {
+        Self {
+            high: i128::from_be_bytes(b[0..16].try_into().unwrap()),
+            low: u128::from_be_bytes(b[16..32].try_into().unwrap()),
+        }
+    }
+
     pub fn from_i128(v: i128) -> Self {
         let mut bytes = if num::Signed::is_negative(&v) {
             [255_u8; 32]
@@ -130,6 +139,17 @@ impl i256 {
         t
     }
 
+    /// Return the memory representation of this integer as a byte array in big-endian byte order.
+    #[inline]
+    pub fn to_be_bytes(self) -> [u8; 32] {
+        let mut t = [0; 32];
+        let t_high: &mut [u8; 16] = (&mut t[0..16]).try_into().unwrap();
+        *t_high = self.high.to_be_bytes();
+        let t_low: &mut [u8; 16] = (&mut t[16..32]).try_into().unwrap();
+        *t_low = self.low.to_be_bytes();
+        t
+    }
+
     /// Create an i256 from the provided [`BigInt`] returning a bool indicating
     /// if overflow occurred
     fn from_bigint_with_overflow(v: BigInt) -> (Self, bool) {
diff --git a/arrow/src/row/dictionary.rs b/arrow/src/row/dictionary.rs
index b06688224..1ec7c2a21 100644
--- a/arrow/src/row/dictionary.rs
+++ b/arrow/src/row/dictionary.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::compute::SortOptions;
-use crate::row::fixed::{FixedLengthEncoding, FromSlice, RawDecimal};
+use crate::row::fixed::{FixedLengthEncoding, FromSlice};
 use crate::row::interner::{Interned, OrderPreservingInterner};
 use crate::row::{null_sentinel, Rows};
 use arrow_array::builder::*;
@@ -173,12 +173,8 @@ pub unsafe fn decode_dictionary<K: ArrowDictionaryKeyType>(
         &value_type => (decode_primitive_helper, values),
         DataType::Null => NullArray::new(values.len()).into_data(),
         DataType::Boolean => decode_bool(&values),
-        DataType::Decimal128(p, s) => {
-            decode_decimal::<16, Decimal128Type>(&values, *p, *s)
-        }
-        DataType::Decimal256(p, s) => {
-            decode_decimal::<32, Decimal256Type>(&values, *p, *s)
-        }
+        DataType::Decimal128(p, s) => 
decode_decimal::<Decimal128Type>(&values, *p, *s),
+        DataType::Decimal256(p, s) => 
decode_decimal::<Decimal256Type>(&values, *p, *s),
         DataType::Utf8 => decode_string::<i32>(&values),
         DataType::LargeUtf8 => decode_string::<i64>(&values),
         DataType::Binary => decode_binary::<i32>(&values),
@@ -279,10 +275,9 @@ where
 }
 
 /// Decodes a `DecimalArray` from dictionary values
-fn decode_decimal<const N: usize, T: DecimalType>(
-    values: &[&[u8]],
-    precision: u8,
-    scale: u8,
-) -> ArrayData {
-    decode_fixed::<RawDecimal<N>>(values, T::TYPE_CONSTRUCTOR(precision, 
scale))
+fn decode_decimal<T: DecimalType>(values: &[&[u8]], precision: u8, scale: u8) 
-> ArrayData
+where
+    T::Native: FixedLengthEncoding,
+{
+    decode_fixed::<T::Native>(values, T::TYPE_CONSTRUCTOR(precision, scale))
 }
diff --git a/arrow/src/row/fixed.rs b/arrow/src/row/fixed.rs
index ec7afd8e3..d5935cfb6 100644
--- a/arrow/src/row/fixed.rs
+++ b/arrow/src/row/fixed.rs
@@ -19,9 +19,8 @@ use crate::array::PrimitiveArray;
 use crate::compute::SortOptions;
 use crate::datatypes::ArrowPrimitiveType;
 use crate::row::{null_sentinel, Rows};
-use arrow_array::types::DecimalType;
 use arrow_array::BooleanArray;
-use arrow_buffer::{bit_util, MutableBuffer, ToByteSlice};
+use arrow_buffer::{bit_util, i256, MutableBuffer, ToByteSlice};
 use arrow_data::{ArrayData, ArrayDataBuilder};
 use arrow_schema::DataType;
 use half::f16;
@@ -91,6 +90,7 @@ encode_signed!(2, i16);
 encode_signed!(4, i32);
 encode_signed!(8, i64);
 encode_signed!(16, i128);
+encode_signed!(32, i256);
 
 macro_rules! encode_unsigned {
     ($n:expr, $t:ty) => {
@@ -164,38 +164,6 @@ impl FixedLengthEncoding for f64 {
     }
 }
 
-pub type RawDecimal128 = RawDecimal<16>;
-pub type RawDecimal256 = RawDecimal<32>;
-
-/// The raw bytes of a decimal
-#[derive(Copy, Clone)]
-pub struct RawDecimal<const N: usize>(pub [u8; N]);
-
-impl<const N: usize> ToByteSlice for RawDecimal<N> {
-    fn to_byte_slice(&self) -> &[u8] {
-        &self.0
-    }
-}
-
-impl<const N: usize> FixedLengthEncoding for RawDecimal<N> {
-    type Encoded = [u8; N];
-
-    fn encode(self) -> [u8; N] {
-        let mut val = self.0;
-        // Convert to big endian representation
-        val.reverse();
-        // Toggle top "sign" bit to ensure consistent sort order
-        val[0] ^= 0x80;
-        val
-    }
-
-    fn decode(mut encoded: Self::Encoded) -> Self {
-        encoded[0] ^= 0x80;
-        encoded.reverse();
-        Self(encoded)
-    }
-}
-
 /// Returns the total encoded length (including null byte) for a value of type 
`T::Native`
 pub const fn encoded_len<T>(_col: &PrimitiveArray<T>) -> usize
 where
@@ -354,17 +322,6 @@ fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
     unsafe { builder.build_unchecked() }
 }
 
-/// Decodes a `DecimalArray` from rows
-pub fn decode_decimal<const N: usize, T: DecimalType + ArrowPrimitiveType>(
-    rows: &mut [&[u8]],
-    options: SortOptions,
-    precision: u8,
-    scale: u8,
-) -> PrimitiveArray<T> {
-    decode_fixed::<RawDecimal<N>>(rows, T::TYPE_CONSTRUCTOR(precision, scale), 
options)
-        .into()
-}
-
 /// Decodes a `PrimitiveArray` from rows
 pub fn decode_primitive<T: ArrowPrimitiveType>(
     rows: &mut [&[u8]],
diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs
index 77c70a5fd..c3aa9ea4c 100644
--- a/arrow/src/row/mod.rs
+++ b/arrow/src/row/mod.rs
@@ -84,6 +84,7 @@ use std::sync::Arc;
 
 use arrow_array::cast::*;
 use arrow_array::*;
+use arrow_buffer::i256;
 
 use crate::compute::SortOptions;
 use crate::datatypes::*;
@@ -91,10 +92,7 @@ use crate::error::{ArrowError, Result};
 use crate::row::dictionary::{
     compute_dictionary_mapping, decode_dictionary, encode_dictionary,
 };
-use crate::row::fixed::{
-    decode_bool, decode_decimal, decode_primitive, RawDecimal, RawDecimal128,
-    RawDecimal256,
-};
+use crate::row::fixed::{decode_bool, decode_primitive};
 use crate::row::interner::OrderPreservingInterner;
 use crate::row::variable::{decode_binary, decode_string};
 use crate::{downcast_dictionary_array, downcast_primitive_array};
@@ -488,8 +486,8 @@ fn new_empty_rows(
             array => lengths.iter_mut().for_each(|x| *x += 
fixed::encoded_len(array)),
             DataType::Null => {},
             DataType::Boolean => lengths.iter_mut().for_each(|x| *x += 
bool::ENCODED_LEN),
-            DataType::Decimal128(_, _) => lengths.iter_mut().for_each(|x| *x 
+= RawDecimal128::ENCODED_LEN),
-            DataType::Decimal256(_, _) => lengths.iter_mut().for_each(|x| *x 
+= RawDecimal256::ENCODED_LEN),
+            DataType::Decimal128(_, _) => lengths.iter_mut().for_each(|x| *x 
+= i128::ENCODED_LEN),
+            DataType::Decimal256(_, _) => lengths.iter_mut().for_each(|x| *x 
+= i256::ENCODED_LEN),
             DataType::Binary => as_generic_binary_array::<i32>(array)
                 .iter()
                 .zip(lengths.iter_mut())
@@ -571,24 +569,20 @@ fn encode_column(
         DataType::Null => {}
         DataType::Boolean => fixed::encode(out, as_boolean_array(column), 
opts),
         DataType::Decimal128(_, _) => {
-            let iter = column
+            let column = column
                 .as_any()
                 .downcast_ref::<Decimal128Array>()
-                .unwrap()
-                .into_iter()
-                .map(|x| x.map(|x| RawDecimal(x.to_le_bytes())));
+                .unwrap();
 
-            fixed::encode(out, iter, opts)
+            fixed::encode(out, column, opts)
         },
         DataType::Decimal256(_, _) => {
-            let iter = column
+            let column = column
                 .as_any()
                 .downcast_ref::<Decimal256Array>()
-                .unwrap()
-                .into_iter()
-                .map(|x| x.map(|x| RawDecimal(x.to_le_bytes())));
+                .unwrap();
 
-            fixed::encode(out, iter, opts)
+            fixed::encode(out, column, opts)
         },
         DataType::Binary => {
             variable::encode(out, 
as_generic_binary_array::<i32>(column).iter(), opts)
@@ -641,12 +635,16 @@ unsafe fn decode_column(
         DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
         DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options)),
         DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options)),
-        DataType::Decimal128(p, s) => {
-            Arc::new(decode_decimal::<16, Decimal128Type>(rows, options, *p, 
*s))
-        }
-        DataType::Decimal256(p, s) => {
-            Arc::new(decode_decimal::<32, Decimal256Type>(rows, options, *p, 
*s))
-        }
+        DataType::Decimal128(p, s) => Arc::new(
+            decode_primitive::<Decimal128Type>(rows, options)
+                .with_precision_and_scale(*p, *s)
+                .unwrap(),
+        ),
+        DataType::Decimal256(p, s) => Arc::new(
+            decode_primitive::<Decimal256Type>(rows, options)
+                .with_precision_and_scale(*p, *s)
+                .unwrap(),
+        ),
         DataType::Dictionary(k, v) => match k.as_ref() {
             DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
                 interner.unwrap(),
@@ -795,6 +793,64 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_decimal128() {
+        let mut converter = RowConverter::new(vec![SortField::new(
+            DataType::Decimal128(DECIMAL128_MAX_PRECISION, 7),
+        )]);
+        let col = Arc::new(
+            Decimal128Array::from_iter([
+                None,
+                Some(i128::MIN),
+                Some(-13),
+                Some(46_i128),
+                Some(5456_i128),
+                Some(i128::MAX),
+            ])
+            .with_precision_and_scale(38, 7)
+            .unwrap(),
+        ) as ArrayRef;
+
+        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
+        for i in 0..rows.num_rows() - 1 {
+            assert!(rows.row(i) < rows.row(i + 1));
+        }
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        assert_eq!(col.as_ref(), back[0].as_ref())
+    }
+
+    #[test]
+    fn test_decimal256() {
+        let mut converter = RowConverter::new(vec![SortField::new(
+            DataType::Decimal256(DECIMAL256_MAX_PRECISION, 7),
+        )]);
+        let col = Arc::new(
+            Decimal256Array::from_iter([
+                None,
+                Some(i256::MIN),
+                Some(i256::from_parts(0, -1)),
+                Some(i256::from_parts(u128::MAX, -1)),
+                Some(i256::from_parts(u128::MAX, 0)),
+                Some(i256::from_parts(0, 46_i128)),
+                Some(i256::from_parts(5, 46_i128)),
+                Some(i256::MAX),
+            ])
+            .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 7)
+            .unwrap(),
+        ) as ArrayRef;
+
+        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
+        for i in 0..rows.num_rows() - 1 {
+            assert!(rows.row(i) < rows.row(i + 1));
+        }
+
+        let back = converter.convert_rows(&rows).unwrap();
+        assert_eq!(back.len(), 1);
+        assert_eq!(col.as_ref(), back[0].as_ref())
+    }
+
     #[test]
     fn test_bool() {
         let mut converter = 
RowConverter::new(vec![SortField::new(DataType::Boolean)]);

Reply via email to