This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-nanoarrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 59a281cc feat(python): Support Decimal types in convert to Python 
(#425)
59a281cc is described below

commit 59a281cca2ed63925931f6fd0dc492a7dbbb4530
Author: Dewey Dunnington <[email protected]>
AuthorDate: Mon Apr 15 09:44:40 2024 -0300

    feat(python): Support Decimal types in convert to Python (#425)
    
    I experimented with a few different methods here; I think this one strikes
    a good balance between speed and not modifying the global decimal precision.
    
    ```python
    import pyarrow as pa
    import decimal
    from nanoarrow.iterator import iter_py
    
    items = [decimal.Decimal("12.3450"), None, decimal.Decimal("1234567.3456")]
    array = pa.array(items, pa.decimal128(11, 4))
    list(iter_py(array))
    #> [Decimal('12.3450'), None, Decimal('1234567.3456')]
    ```
    
    This seems to be roughly on par with (here, somewhat faster than) pyarrow's conversion:
    
    ```python
    import pyarrow as pa
    import decimal
    import numpy as np
    from nanoarrow.iterator import iter_py
    
    floats = np.random.random(int(1e6))
    items = [decimal.Decimal(item) for item in floats]
    array = pa.array(items)
    
    %timeit array.to_pylist()
    #> 799 ms ± 6.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    %timeit list(iter_py(array))
    #> 431 ms ± 8.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    ```
---
 python/src/nanoarrow/iterator.py | 31 +++++++++++++++++++++++++++++--
 python/tests/test_iterator.py    | 16 ++++++++++++++++
 2 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/python/src/nanoarrow/iterator.py b/python/src/nanoarrow/iterator.py
index 111dbe0d..5eff7a9e 100644
--- a/python/src/nanoarrow/iterator.py
+++ b/python/src/nanoarrow/iterator.py
@@ -263,6 +263,33 @@ class PyIterator(ArrayViewIterator):
             for start, end in zip(starts, ends):
                 yield bytes(data[start:end])
 
+    def _decimal_iter(self, offset, length):
+        from decimal import Context, Decimal
+        from sys import byteorder
+
+        storage = self._primitive_iter(offset, length)
+        precision = self._schema_view.decimal_precision
+
+        # The approach here it to use Decimal(<integer>).scaleb(-scale),
+        # which is a balance between simplicity, performance, and
+        # safety (ensuring that we stay independent from the global precision).
+        # We cache the scaleb and context to avoid doing so in the loop (the
+        # argument to scaleb is transformed to a decimal by the .scaleb()
+        # implementation).
+        #
+        # It would probably be fastest to go straight from binary
+        # to string to decimal, since creating a decimal from a string
+        # appears to be the fastest constructor.
+        scaleb = Decimal(-self._schema_view.decimal_scale)
+        context = Context(prec=precision)
+
+        for item in storage:
+            if item is None:
+                yield None
+            else:
+                int_value = int.from_bytes(item, byteorder)
+                yield Decimal(int_value).scaleb(scaleb, context)
+
     def _date_iter(self, offset, length):
         from datetime import date, timedelta
 
@@ -471,6 +498,8 @@ _ITEMS_ITER_LOOKUP = {
     CArrowType.TIME64: "_time_iter",
     CArrowType.TIMESTAMP: "_timestamp_iter",
     CArrowType.DURATION: "_duration_iter",
+    CArrowType.DECIMAL128: "_decimal_iter",
+    CArrowType.DECIMAL256: "_decimal_iter",
 }
 
 _PRIMITIVE_TYPE_NAMES = [
@@ -490,8 +519,6 @@ _PRIMITIVE_TYPE_NAMES = [
     "INTERVAL_MONTHS",
     "INTERVAL_DAY_TIME",
     "INTERVAL_MONTH_DAY_NANO",
-    "DECIMAL128",
-    "DECIMAL256",
 ]
 
 for type_name in _PRIMITIVE_TYPE_NAMES:
diff --git a/python/tests/test_iterator.py b/python/tests/test_iterator.py
index 72a55414..8cbd66de 100644
--- a/python/tests/test_iterator.py
+++ b/python/tests/test_iterator.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import datetime
+import decimal
 
 import pytest
 from nanoarrow.iterator import (
@@ -334,6 +335,21 @@ def test_iterator_nullable_dictionary():
     assert list(iter_py(sliced)) == ["cde", "ab", "def", "cde", None]
 
 
+def test_iterator_decimal():
+    pa = pytest.importorskip("pyarrow")
+
+    items = [decimal.Decimal("12.3450"), None, decimal.Decimal("1234567.3456")]
+    array = pa.array(items, pa.decimal128(11, 4))
+    assert list(iter_py(array)) == items
+
+    array = pa.array(items, pa.decimal256(11, 4))
+    assert list(iter_py(array)) == items
+
+    # Make sure this isn't affected by user-modified context
+    with decimal.localcontext(decimal.Context(prec=1)):
+        assert list(iter_py(array)) == items
+
+
 def test_iterator_date():
     pa = pytest.importorskip("pyarrow")
 

Reply via email to