[ https://issues.apache.org/jira/browse/ARROW-13917?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Will Jones closed ARROW-13917.
------------------------------
    Resolution: Not A Problem

> [Gandiva] Add helper to determine valid decimal function return type
> --------------------------------------------------------------------
>
>                 Key: ARROW-13917
>                 URL: https://issues.apache.org/jira/browse/ARROW-13917
>             Project: Apache Arrow
>          Issue Type: Improvement
>          Components: C++ - Gandiva
>            Reporter: Will Jones
>            Priority: Minor
>
> To evaluate a Gandiva function, you need to pass its return type. For most
> types, we can look up the possible return types using the
> `GetRegisteredFunctionSignatures` method, but those signatures don't include
> the precision and scale parameters of the decimal type.
> Specifying the precision and scale of a decimal return type is left to the
> user, and if the user gets them wrong, the function can silently return
> incorrect results. See the reproducible example at the bottom.
> The precision and scale of the return type depend on the input types and on
> how each decimal operation is implemented. Since that logic varies across
> functions (add, divide, trunc, round), it would be best to provide a utility
> that helps the user determine the precise return type.
> Note that return types aren't unique for a given function name and set of
> parameter types. For example, `add(date64[ms], int64)` can return either
> `date64[ms]` or `timestamp[ms]`, so a generic utility has to be able to
> return multiple possible return types.
> Example of invalid decimal results caused by a bad return type:
> {code:python}
> from decimal import Decimal
> import pyarrow as pa
> from pyarrow.gandiva import TreeExprBuilder, make_projector
>
> def call_on_value(func, values, params, out_type):
>     """Evaluate Gandiva function `func` on single-row columns built from
>     `values` (a list of (value, type) pairs) plus literal arguments
>     `params` (also (value, type) pairs), declaring `out_type` as the
>     return type."""
>     builder = TreeExprBuilder()
>
>     # Literal arguments, e.g. the number of digits passed to round()
>     param_literals = []
>     for param, param_type in params:
>         param_literals.append(builder.make_literal(param, param_type))
>
>     # One single-row column per input value, named "0", "1", ...
>     inputs = []
>     arrays = []
>     for i, (value, value_type) in enumerate(values):
>         inputs.append(builder.make_field(pa.field(str(i), value_type)))
>         arrays.append(pa.array([value], value_type))
>
>     record_batch = pa.record_batch(arrays, [str(i) for i in range(len(values))])
>
>     # The caller supplies the return type; nothing checks that it is valid.
>     func_x = builder.make_function(func, inputs + param_literals, out_type)
>     expressions = [builder.make_expression(func_x, pa.field('result', out_type))]
>
>     projector = make_projector(record_batch.schema, expressions,
>                                pa.default_memory_pool())
>     return projector.evaluate(record_batch)
>
> call_on_value(
>     'round',
>     [(Decimal("123.459"), pa.decimal128(28, 3))],
>     [(2, pa.int32())],
>     pa.decimal128(28, 3)
> )
> # Returns: 123.459 (not rounded!)
>
> call_on_value(
>     'round',
>     [(Decimal("123.459"), pa.decimal128(28, 3))],
>     [(-2, pa.int32())],
>     pa.decimal128(28, 3)
> )
> # Returns: 0.100 (😵)
> {code}
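
For the binary arithmetic functions, the return-type helper proposed in the description could look something like the minimal sketch below. The name `decimal_result_type` is hypothetical, and the rules are an assumption: they are the SQL Server/Hive-style decimal rules (including capping at precision 38 and keeping at least 6 fractional digits when scale must be reduced), which Gandiva's C++ decimal type utilities appear to be modeled on, but they should be checked against the Gandiva source before being relied on. `round` and `trunc` are deliberately not covered, since their rules are not stated here.

```python
from typing import Tuple

MAX_PRECISION = 38      # decimal128 maximum precision
MIN_ADJUSTED_SCALE = 6  # assumed minimum scale kept when precision overflows


def _adjust(precision: int, scale: int) -> Tuple[int, int]:
    """Clamp an oversized result to decimal128 by giving up fractional digits."""
    if precision <= MAX_PRECISION:
        return precision, scale
    min_scale = min(scale, MIN_ADJUSTED_SCALE)
    delta = precision - MAX_PRECISION
    return MAX_PRECISION, max(scale - delta, min_scale)


def decimal_result_type(op: str, p1: int, s1: int, p2: int, s2: int) -> Tuple[int, int]:
    """Hypothetical helper: (precision, scale) of `op` applied to
    decimal(p1, s1) and decimal(p2, s2), under the assumed rules above."""
    if op in ("add", "subtract"):
        scale = max(s1, s2)
        precision = max(p1 - s1, p2 - s2) + scale + 1
    elif op == "multiply":
        scale = s1 + s2
        precision = p1 + p2 + 1
    elif op == "divide":
        scale = max(MIN_ADJUSTED_SCALE, s1 + p2 + 1)
        precision = p1 - s1 + s2 + scale
    elif op == "mod":
        scale = max(s1, s2)
        precision = min(p1 - s1, p2 - s2) + scale
    else:
        raise ValueError(f"no return-type rule for {op!r}")
    return _adjust(precision, scale)


print(decimal_result_type("add", 28, 3, 28, 3))       # (29, 3)
print(decimal_result_type("multiply", 28, 3, 28, 3))  # (38, 6)
```

The result could then be passed as `pa.decimal128(*decimal_result_type(...))` for the `out_type` argument of `call_on_value` above.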
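
Because one function name and parameter list can map to several return types (the `add(date64[ms], int64)` case), a generic utility would need to collect every registered return type per signature. The sketch below assumes only that each signature object has `name()`, `param_types()` and `return_type()` methods; `possible_return_types` and `FakeSignature` are hypothetical names, and `FakeSignature` is just a stand-in so the example runs without a Gandiva build. With a Gandiva-enabled pyarrow, the input would presumably come from `pyarrow.gandiva.get_registered_function_signatures()`, assuming your build exposes it.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

Key = Tuple[str, Tuple[str, ...]]


def possible_return_types(signatures: Iterable) -> Dict[Key, List[str]]:
    """Group signatures by (name, parameter types), keeping every distinct
    return type. Types are compared by their string form, e.g. "date64[ms]"."""
    table: Dict[Key, List[str]] = defaultdict(list)
    for sig in signatures:
        key = (sig.name(), tuple(str(t) for t in sig.param_types()))
        ret = str(sig.return_type())
        if ret not in table[key]:
            table[key].append(ret)
    return dict(table)


class FakeSignature:
    """Hypothetical stand-in for a registered function signature."""

    def __init__(self, name, params, ret):
        self._name, self._params, self._ret = name, params, ret

    def name(self):
        return self._name

    def param_types(self):
        return self._params

    def return_type(self):
        return self._ret


sigs = [
    FakeSignature("add", ["date64[ms]", "int64"], "date64[ms]"),
    FakeSignature("add", ["date64[ms]", "int64"], "timestamp[ms]"),
    FakeSignature("add", ["int64", "int64"], "int64"),
]
lookup = possible_return_types(sigs)
print(lookup[("add", ("date64[ms]", "int64"))])  # ['date64[ms]', 'timestamp[ms]']
```

A decimal-aware version would still need per-function precision and scale rules on top of this lookup, since the registered signatures don't carry them.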



--
This message was sent by Atlassian Jira
(v8.3.4#803005)
