pitrou commented on code in PR #35997:
URL: https://github.com/apache/arrow/pull/35997#discussion_r1229431310


##########
python/pyarrow/tests/test_compute.py:
##########
@@ -1822,6 +1825,128 @@ def test_fsl_to_fsl_cast(value_type):
         fsl.cast(cast_type)
 
 
+DecimalTypeTraits = namedtuple('DecimalTypeTraits',
+                               ('name', 'factory', 'max_precision'))
+
+FloatToDecimalCase = namedtuple('FloatToDecimalCase',
+                                ('precision', 'scale', 'float_val'))
+
+decimal_type_traits = [DecimalTypeTraits('decimal128', pa.decimal128, 38),
+                       DecimalTypeTraits('decimal256', pa.decimal256, 76)]
+
+
+def largest_scaled_float_not_above(val, scale):
+    """
+    Find the largest float f such that `f * 10**scale <= val`
+    """
+    assert val >= 0
+    assert scale >= 0
+    float_val = float(val) / 10**scale
+    if float_val * 10**scale > val:
+        # Take the float just below... it *should* satisfy
+        float_val = np.nextafter(float_val, 0.0)
+        if float_val * 10**scale > val:
+            float_val = np.nextafter(float_val, 0.0)
+    assert float_val * 10**scale <= val
+    return float_val
+
+
+def scaled_float(int_val, scale):
+    """
+    Return a (possibly approximate) float value of `int_val * 10**-scale`
+    """
+    assert isinstance(int_val, int)
+    unscaled = decimal.Decimal(int_val)
+    scaled = unscaled.scaleb(-scale)
+    float_val = float(scaled)
+    return float_val
+
+
+def integral_float_to_decimal_cast_cases(float_ty, max_precision):
+    """
+    Return FloatToDecimalCase instances with integral values.
+    """
+    mantissa_digits = 16
+    for precision in range(1, max_precision, 3):
+        for scale in range(0, precision, 2):
+            yield FloatToDecimalCase(precision, scale, 0.0)
+            yield FloatToDecimalCase(precision, scale, 1.0)
+            epsilon = 10**max(precision - mantissa_digits, scale)
+            abs_maxval = largest_scaled_float_not_above(
+                10**precision - epsilon, scale)
+            yield FloatToDecimalCase(precision, scale, abs_maxval)
+
+
+def real_float_to_decimal_cast_cases(float_ty, max_precision):
+    """
+    Return FloatToDecimalCase instances with real values.
+    """
+    mantissa_digits = 16
+    for precision in range(1, max_precision, 3):
+        for scale in range(0, precision, 2):
+            epsilon = 2 * 10**max(precision - mantissa_digits, 0)
+            abs_minval = largest_scaled_float_not_above(epsilon, scale)
+            abs_maxval = largest_scaled_float_not_above(
+                10**precision - epsilon, scale)
+            yield FloatToDecimalCase(precision, scale, abs_minval)
+            yield FloatToDecimalCase(precision, scale, abs_maxval)
+
+
+def random_float_to_decimal_cast_cases(float_ty, max_precision):
+    """
+    Return randomly generated FloatToDecimalCase instances.
+    """
+    r = random.Random(42)
+    for precision in range(1, max_precision, 6):
+        for scale in range(0, precision, 4):
+            for i in range(20):
+                unscaled = r.randrange(0, 10**precision)
+                float_val = scaled_float(unscaled, scale)
+                assert float_val * 10**scale < 10**precision
+                yield FloatToDecimalCase(precision, scale, float_val)
+
+
+def check_cast_float_to_decimal(float_ty, float_val, decimal_ty, decimal_ctx,
+                                max_precision):
+    # Use the Python decimal module to build the expected result
+    # using the right precision
+    decimal_ctx.prec = decimal_ty.precision
+    decimal_ctx.rounding = decimal.ROUND_HALF_EVEN
+    expected = decimal_ctx.create_decimal_from_float(float_val)
+    # Round `expected` to `scale` digits after the decimal point
+    expected = expected.quantize(decimal.Decimal(1).scaleb(-decimal_ty.scale))
+    s = pa.scalar(float_val, type=float_ty)
+    actual = pc.cast(s, decimal_ty).as_py()
+    if actual != expected:
+        # Allow the last digit to vary. The tolerance is higher for
+        # very high precisions as rounding errors can accumulate in
+        # the iterative algorithm (GH-35576).
+        diff_digits = abs(actual - expected) * 10**decimal_ty.scale
+        limit = 2 if decimal_ty.precision < max_precision - 1 else 4
+        assert diff_digits <= limit, (
+            f"float_val = {float_val!r}, precision={decimal_ty.precision}, "
+            f"expected = {expected!r}, actual = {actual!r}, "
+            f"diff_digits = {diff_digits!r}")
+
+
+# XXX Cannot test float32 as case generators above assume float64

Review Comment:
   Not sure. I wouldn't create an issue for it, as the benefit-to-effort ratio
is probably low.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to