[
https://issues.apache.org/jira/browse/ARROW-13917?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
Will Jones closed ARROW-13917.
------------------------------
Resolution: Not A Problem
> [Gandiva] Add helper to determine valid decimal function return type
> --------------------------------------------------------------------
>
> Key: ARROW-13917
> URL: https://issues.apache.org/jira/browse/ARROW-13917
> Project: Apache Arrow
> Issue Type: Improvement
> Components: C++ - Gandiva
> Reporter: Will Jones
> Priority: Minor
>
> To evaluate a Gandiva function, you need to pass its return type. For most
> types, we can look up the possible return types by using the
> `GetRegisteredFunctionSignatures` method, but those don't include details of
> the precision and scale parameters of the decimal type.
> Specifying the precision and scale parameters of the decimal type is left up
> to the user, but if the user gets it wrong, they can get invalid answers.
> See the reproducible example at the bottom.
> The precision and scale of the return type depend on the input types and the
> implementation of the decimal operations. Given the variation of logic across
> different functions (add, divide, trunc, round), it would be best if we were
> able to provide some utility to help the user determine the precise return
> type.
> Note that return types aren't always unique for a given function name and
> parameter types. For example, `add(date64[ms], int64)` can return either
> `date64[ms]` or `timestamp[ms]`. So a generic utility has to return multiple
> possible return types.
> Example of invalid decimal results from bad return type:
> {code:python}
> from decimal import Decimal
> import pyarrow as pa
> from pyarrow.gandiva import TreeExprBuilder, make_projector
>
> def call_on_value(func, values, params, out_type):
>     builder = TreeExprBuilder()
>
>     param_literals = []
>     for param, param_type in params:
>         param_literals.append(builder.make_literal(param, param_type))
>
>     inputs = []
>     arrays = []
>     for i, value in enumerate(values):
>         inputs.append(builder.make_field(pa.field(str(i), value[1])))
>         arrays.append(pa.array([value[0]], value[1]))
>
>     record_batch = pa.record_batch(
>         arrays, [str(i) for i in range(len(values))])
>
>     func_x = builder.make_function(func, inputs + param_literals, out_type)
>
>     expressions = [
>         builder.make_expression(func_x, pa.field('result', out_type))]
>
>     projector = make_projector(
>         record_batch.schema, expressions, pa.default_memory_pool())
>
>     return projector.evaluate(record_batch)
>
> # `values` is a list of (value, type) pairs, one per input column.
> call_on_value(
>     'round',
>     [(Decimal("123.459"), pa.decimal128(28, 3))],
>     [(2, pa.int32())],
>     pa.decimal128(28, 3)
> )
> # Returns: 123.459 (not rounded!)
> call_on_value(
>     'round',
>     [(Decimal("123.459"), pa.decimal128(28, 3))],
>     [(-2, pa.int32())],
>     pa.decimal128(28, 3)
> )
> # Returns: 0.100 (😵)
> {code}
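
For illustration, the kind of helper described in the issue could look roughly like the sketch below for the binary arithmetic operations. It assumes the SQL-style precision and scale rules that Gandiva's C++ `DecimalTypeUtil` is understood to apply; the function name `decimal_result_type`, the operation strings, and the two constants are hypothetical and not part of any Arrow API. `round` and `trunc` are not covered, since their result type also depends on the literal scale argument.

```python
from typing import Tuple

MAX_PRECISION = 38      # widest precision a decimal128 can hold
MIN_ADJUSTED_SCALE = 6  # assumed floor for the scale when precision overflows


def decimal_result_type(op: str, p1: int, s1: int, p2: int, s2: int) -> Tuple[int, int]:
    """Return (precision, scale) for `op` applied to decimal(p1, s1) and decimal(p2, s2).

    Hypothetical helper; the rules are an assumption, not a confirmed Arrow API.
    """
    if op in ("add", "subtract"):
        scale = max(s1, s2)
        precision = max(p1 - s1, p2 - s2) + scale + 1
    elif op == "multiply":
        scale = s1 + s2
        precision = p1 + p2 + 1
    elif op == "divide":
        scale = max(MIN_ADJUSTED_SCALE, s1 + p2 + 1)
        precision = p1 - s1 + s2 + scale
    elif op == "mod":
        scale = max(s1, s2)
        precision = min(p1 - s1, p2 - s2) + scale
    else:
        raise ValueError(f"unsupported operation: {op}")

    # If the result would not fit in decimal128, cap the precision and give up
    # fractional digits, but never drop the scale below the assumed floor.
    if precision > MAX_PRECISION:
        min_scale = min(scale, MIN_ADJUSTED_SCALE)
        delta = precision - MAX_PRECISION
        precision = MAX_PRECISION
        scale = max(scale - delta, min_scale)

    return precision, scale
```

Under these assumed rules, adding two `decimal128(28, 3)` values would call for a `decimal128(29, 3)` return type, which could then be passed as `out_type` instead of a hand-picked guess.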
--
This message was sent by Atlassian Jira
(v8.3.4#803005)