Will Jones created ARROW-13917:
----------------------------------

             Summary: [Gandiva] Add helper to determine valid decimal function return type
                 Key: ARROW-13917
                 URL: https://issues.apache.org/jira/browse/ARROW-13917
             Project: Apache Arrow
          Issue Type: Improvement
          Components: C++ - Gandiva
            Reporter: Will Jones


To evaluate a Gandiva function, you need to pass its return type. For most 
types, we can look up the possible return types using the 
`GetRegisteredFunctionSignatures` method, but those signatures don't include 
the precision and scale parameters of the decimal type.

Choosing the precision and scale of a decimal return type is left up to the 
user, but if the user gets it wrong, they can get incorrect results. See the 
reproducible example at the bottom.

The precision and scale of the return type depend on the input types and on 
the implementation of each decimal operation. Since that logic varies across 
functions (add, divide, trunc, round), it would be best to provide a utility 
that helps the user determine the precise return type.
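As a rough illustration of what such a utility could compute, here is a minimal sketch covering the binary arithmetic operations. It assumes SQL Server/Redshift-style decimal rules with a maximum precision of 38 and a minimum adjusted scale of 6, which we believe match Gandiva's C++ `DecimalTypeUtil` but have not verified here. The names `decimal_result_type` and `_adjust` are hypothetical, and `round`/`trunc` are left out because their rules are not sketched.

{code:python}
MAX_PRECISION = 38       # assumed decimal128 precision limit
MIN_ADJUSTED_SCALE = 6   # assumed minimum scale kept when precision overflows


def _adjust(precision, scale):
    """Cap precision at MAX_PRECISION, giving up scale (down to a floor) to fit."""
    if precision > MAX_PRECISION:
        min_scale = min(scale, MIN_ADJUSTED_SCALE)
        delta = precision - MAX_PRECISION
        precision = MAX_PRECISION
        scale = max(scale - delta, min_scale)
    return precision, scale


def decimal_result_type(op, left, right):
    """Hypothetical helper: (precision, scale) of `left op right`.

    `left` and `right` are (precision, scale) tuples. The rules below are
    assumptions modeled on SQL Server/Redshift, not a confirmed Gandiva API.
    """
    p1, s1 = left
    p2, s2 = right
    if op in ("add", "subtract"):
        scale = max(s1, s2)
        precision = max(p1 - s1, p2 - s2) + scale + 1
    elif op == "multiply":
        scale = s1 + s2
        precision = p1 + p2 + 1
    elif op == "divide":
        scale = max(MIN_ADJUSTED_SCALE, s1 + p2 + 1)
        precision = p1 - s1 + s2 + scale
    else:
        raise ValueError(f"no rule sketched for {op!r}")
    return _adjust(precision, scale)
{code}

For example, under these assumed rules `decimal_result_type("add", (28, 3), (10, 2))` gives `(29, 3)`, while `"divide"` on the same inputs exceeds the precision limit and is adjusted to `(38, 11)`.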

Note also that return types aren't unique for a given function name and 
parameter types. For example, `add(date64[ms], int64)` can return either 
`date64[ms]` or `timestamp[ms]`. So a generic utility has to return all 
possible return types.
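A minimal sketch of the shape such a generic lookup could take: given a function name and parameter types, it returns every registered return type rather than a single one. The tiny registry below is hypothetical and uses type names as plain strings; a real implementation would build it from `GetRegisteredFunctionSignatures` and would still need separate logic to fill in decimal precision and scale.

{code:python}
from collections import defaultdict

# Hypothetical registry entries: (function name, parameter types, return type).
# Type names are plain strings purely for illustration.
SIGNATURES = [
    ("add", ("int64", "int64"), "int64"),
    ("add", ("date64[ms]", "int64"), "date64[ms]"),
    ("add", ("date64[ms]", "int64"), "timestamp[ms]"),
]


def build_index(signatures):
    """Group return types by (name, param_types), keeping every match."""
    index = defaultdict(list)
    for name, params, ret in signatures:
        index[(name, tuple(params))].append(ret)
    return index


def possible_return_types(index, name, param_types):
    """Return all registered return types; an empty list means no match."""
    # .get() avoids inserting empty entries into the defaultdict
    return list(index.get((name, tuple(param_types)), []))
{code}

Returning a list keeps the ambiguity visible to the caller, e.g. `possible_return_types(build_index(SIGNATURES), "add", ["date64[ms]", "int64"])` yields both `date64[ms]` and `timestamp[ms]`.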


Example of invalid decimal results caused by a bad return type:

{code:python}
from decimal import Decimal
import pyarrow as pa
from pyarrow.gandiva import TreeExprBuilder, make_projector

def call_on_value(func, values, params, out_type):
    """Evaluate `func` on one row built from `values` plus literal `params`.

    `values` and `params` are lists of (value, arrow_type) pairs.
    """
    builder = TreeExprBuilder()

    # Trailing literal arguments (e.g. the number of digits for round)
    param_literals = []
    for param, param_type in params:
        param_literals.append(builder.make_literal(param, param_type))

    # One single-row column per input value, named "0", "1", ...
    inputs = []
    arrays = []
    for i, (value, value_type) in enumerate(values):
        inputs.append(builder.make_field(pa.field(str(i), value_type)))
        arrays.append(pa.array([value], value_type))

    record_batch = pa.record_batch(arrays, [str(i) for i in range(len(values))])

    # out_type is chosen by the caller; a wrong precision/scale gives wrong results
    func_x = builder.make_function(func, inputs + param_literals, out_type)

    expressions = [builder.make_expression(func_x, pa.field('result', out_type))]

    projector = make_projector(record_batch.schema, expressions,
                               pa.default_memory_pool())

    return projector.evaluate(record_batch)

call_on_value(
    'round',
    [(Decimal("123.459"), pa.decimal128(28, 3))],
    [(2, pa.int32())],
    pa.decimal128(28, 3)
)
# Returns: 123.459 (not rounded! expected 123.46)

call_on_value(
    'round',
    [(Decimal("123.459"), pa.decimal128(28, 3))],
    [(-2, pa.int32())],
    pa.decimal128(28, 3)
)
# Returns: 0.100 (😵, expected 100)
{code}
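For comparison, the values a correct projection should return can be computed with Python's standard `decimal` module. This is a sketch, not Gandiva's implementation: `expected_round` is a hypothetical helper, and `ROUND_HALF_UP` (half away from zero) is an assumed tie-breaking rule, which does not matter for these inputs since neither is a tie.

{code:python}
from decimal import Decimal, ROUND_HALF_UP


def expected_round(value: Decimal, digits: int) -> Decimal:
    """Round `value` to `digits` decimal places; negative digits round left of the point."""
    # Decimal(1).scaleb(-digits) is 10**-digits: 0.01 for digits=2, 1E+2 for digits=-2
    quantum = Decimal(1).scaleb(-digits)
    return value.quantize(quantum, rounding=ROUND_HALF_UP)


print(expected_round(Decimal("123.459"), 2))   # 123.46
print(expected_round(Decimal("123.459"), -2))  # 1E+2, i.e. 100
{code}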




--
This message was sent by Atlassian Jira
(v8.3.4#803005)
