felipecrv commented on code in PR #6906:
URL: https://github.com/apache/arrow-rs/pull/6906#discussion_r1894304597
##########
arrow-arith/src/numeric.rs:
##########
@@ -621,6 +772,98 @@ fn interval_op<T: IntervalOp>(
}
}
+/// Perform multiplication between an interval array and a numeric array
+fn interval_mul_op<T: IntervalOp>(
+ op: Op,
+ l: &dyn Array,
+ l_s: bool,
+ r: &dyn Array,
+ r_s: bool,
+) -> Result<ArrayRef, ArrowError> {
Review Comment:
As I said in the first comment, these should become different cases in
```rust
fn interval_op<T: IntervalOp>(
    op: Op,
    l: &dyn Array,
    l_s: bool,
    r: &dyn Array,
    r_s: bool,
) -> Result<ArrayRef, ArrowError> {
    let l = l.as_primitive::<T>();
    let r = r.as_primitive::<T>();
    match op {
        Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::add(l, r))),
        Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub(l, r))),
        // -- NEW CASES HERE --
```
And then, instead of relying on `try_op_ref!`, which produces array operations
from single-value functions, you first convert the interval array to an
integer array (for both mul and div, so there is no need for different traits),
then delegate to either mul or div of numeric inputs (no interval types are
involved at this point), then take that output and convert it to the desired
interval type with the appropriate kernel. These conversion functions might be
defined with `try_op_ref!`.
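
A minimal sketch of that three-step pipeline, using only the standard library.
Everything here is a hypothetical stand-in, not arrow-rs API: `DayTime` models
an `IntervalDayTimeType` value, slices model arrays, `Op` is reduced to two
variants, and errors are plain `String`s instead of `ArrowError`.

```rust
/// Hypothetical stand-in for a day-time interval value (days + milliseconds).
#[derive(Debug, Clone, Copy, PartialEq)]
struct DayTime {
    days: i32,
    millis: i32,
}

/// Hypothetical, reduced version of the operation enum.
#[derive(Debug, Clone, Copy)]
enum Op {
    Mul,
    Div,
}

const MILLIS_IN_A_DAY: i64 = 86_400_000;

/// Step 1: convert the interval "array" to a plain i64 array of milliseconds.
fn to_millis(intervals: &[DayTime]) -> Vec<i64> {
    intervals
        .iter()
        .map(|v| v.days as i64 * MILLIS_IN_A_DAY + v.millis as i64)
        .collect()
}

/// Step 2: an ordinary numeric kernel; no interval types are involved here.
fn numeric_op(op: Op, values: &[i64], rhs: i64) -> Result<Vec<i64>, String> {
    values
        .iter()
        .map(|&v| -> Result<i64, String> {
            match op {
                Op::Mul => v.checked_mul(rhs).ok_or_else(|| "overflow".to_string()),
                Op::Div => v
                    .checked_div(rhs)
                    .ok_or_else(|| "division by zero or overflow".to_string()),
            }
        })
        .collect()
}

/// Step 3: convert the numeric result back to the interval type.
fn from_millis(values: &[i64]) -> Result<Vec<DayTime>, String> {
    values
        .iter()
        .map(|&ms| -> Result<DayTime, String> {
            let days = i32::try_from(ms / MILLIS_IN_A_DAY)
                .map_err(|_| "days out of i32 range".to_string())?;
            // The remainder is smaller in magnitude than one day, so it fits in i32.
            let millis = (ms % MILLIS_IN_A_DAY) as i32;
            Ok(DayTime { days, millis })
        })
        .collect()
}

/// The whole pipeline: interval -> integer -> numeric op -> interval.
fn interval_mul_div(op: Op, l: &[DayTime], rhs: i64) -> Result<Vec<DayTime>, String> {
    let millis = to_millis(l);
    let result = numeric_op(op, &millis, rhs)?;
    from_millis(&result)
}
```

In this sketch only `to_millis` and `from_millis` know about the interval
layout; `numeric_op` is shared by both operations.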
##########
arrow-arith/src/numeric.rs:
##########
@@ -230,6 +230,24 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, A
         (Interval(YearMonth), Interval(YearMonth)) => interval_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
         (Interval(DayTime), Interval(DayTime)) => interval_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
         (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
+        (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Mul | Op::MulWrapping) =>
+            match unit {
+                YearMonth => interval_mul_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
+                DayTime => interval_mul_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
+                MonthDayNano => interval_mul_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
+            },
+        (lhs, Interval(unit)) if lhs.is_numeric() && matches!(op, Op::Mul | Op::MulWrapping) =>
+            match unit {
+                YearMonth => interval_mul_op::<IntervalYearMonthType>(op, r, r_scalar, l, l_scalar),
+                DayTime => interval_mul_op::<IntervalDayTimeType>(op, r, r_scalar, l, l_scalar),
+                MonthDayNano => interval_mul_op::<IntervalMonthDayNanoType>(op, r, r_scalar, l, l_scalar),
+            },
+        (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Div) =>
+            match unit {
+                YearMonth => interval_div_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
+                DayTime => interval_div_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
+                MonthDayNano => interval_div_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
+            },
Review Comment:
I think it makes more sense to keep these patterns concerned only with the
lhs/rhs types and leave the switch on `op` to `interval_op`, instead of adding
`interval_mul_op`/`interval_div_op`. Since the preparation for dispatching a
mul or a div is similar, keeping it all inside `interval_op` and switching on
`op` there may reduce binary size.
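
A rough sketch of that split, under simplifying assumptions: the outer match
looks only at operand kinds, and a single `interval_op` switches on `op`. The
`Kind` tag, the `swapped` flag, the scalar `i32` operands (months) and the
`String` errors are all hypothetical simplifications, not the arrow-rs
signatures.

```rust
/// Hypothetical, reduced operation enum.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Op {
    Add,
    Sub,
    Mul,
    Div,
}

/// Hypothetical tag standing in for the operand's data type.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind {
    Interval,
    Int,
}

/// The outer dispatch only inspects the lhs/rhs kinds, never `op`.
fn arithmetic_op(op: Op, l: (Kind, i32), r: (Kind, i32)) -> Result<i32, String> {
    match (l.0, r.0) {
        (Kind::Interval, _) => interval_op(op, l.1, r, false),
        // number-on-the-left: pass the interval first and record the swap
        (Kind::Int, Kind::Interval) => interval_op(op, r.1, l, true),
        (Kind::Int, Kind::Int) => Err("not an interval operation".to_string()),
    }
}

/// One function owns the switch on `op` for every interval operation.
fn interval_op(op: Op, interval: i32, other: (Kind, i32), swapped: bool) -> Result<i32, String> {
    let overflow = || "overflow".to_string();
    match (op, other.0, swapped) {
        (Op::Add, Kind::Interval, _) => interval.checked_add(other.1).ok_or_else(overflow),
        (Op::Sub, Kind::Interval, false) => interval.checked_sub(other.1).ok_or_else(overflow),
        // multiplication is commutative, so the swapped form is accepted too
        (Op::Mul, Kind::Int, _) => interval.checked_mul(other.1).ok_or_else(overflow),
        // `number / interval` is rejected: only the unswapped form is handled
        (Op::Div, Kind::Int, false) => interval
            .checked_div(other.1)
            .ok_or_else(|| "division by zero or overflow".to_string()),
        _ => Err(format!("unsupported interval operation: {:?}", op)),
    }
}
```

With this shape, supporting another operation means adding an arm inside
`interval_op` rather than another guarded pattern in the outer match.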
##########
arrow-arith/src/numeric.rs:
##########
@@ -574,6 +574,17 @@ trait IntervalOp: ArrowPrimitiveType {
     fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError>;
 }
+/// Helper function to safely convert f64 to i32, checking for overflow and invalid values
+fn f64_to_i32(value: f64) -> Result<i32, ArrowError> {
+    if !value.is_finite() || value > i32::MAX as f64 || value < i32::MIN as f64 {
+        Err(ArrowError::ComputeError(
+            "Division result out of i32 range".to_string(),
+        ))
+    } else {
+        Ok(value as i32)
+    }
Review Comment:
If this is approached in the way I described above, you won't be converting
floats to integers. When you get a floating-point array representing the
number of milliseconds and need to convert it to `IntervalDayTimeType`, you
divide the input by MILLIS_IN_A_DAY and round the remainder, which comfortably
fits in a 32-bit integer.
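
A small sketch of that conversion for a single value, assuming a hypothetical
`DayTime` struct in place of the real native type and `String` errors in place
of `ArrowError`. The range check on the day count and the choice of truncation
plus rounding are assumptions of this sketch, not something stated in the PR.

```rust
/// Hypothetical stand-in for a day-time interval value.
#[derive(Debug, PartialEq)]
struct DayTime {
    days: i32,
    millis: i32,
}

const MILLIS_IN_A_DAY: f64 = 86_400_000.0;

/// Split a floating-point millisecond count into whole days plus a rounded remainder.
fn millis_to_day_time(total_millis: f64) -> Result<DayTime, String> {
    if !total_millis.is_finite() {
        return Err("non-finite millisecond count".to_string());
    }
    // Whole days, truncated toward zero.
    let days = (total_millis / MILLIS_IN_A_DAY).trunc();
    if days < i32::MIN as f64 || days > i32::MAX as f64 {
        return Err("day count out of i32 range".to_string());
    }
    // The remainder has magnitude below one day (at most 86_400_000 after
    // rounding), so the cast to i32 cannot overflow.
    let rest = (total_millis % MILLIS_IN_A_DAY).round();
    Ok(DayTime {
        days: days as i32,
        millis: rest as i32,
    })
}
```

One caveat of this sketch: a remainder just under a full day can round up to
exactly 86_400_000, so a real kernel might want to normalize that into an
extra day.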
##########
arrow-arith/src/numeric.rs:
##########
@@ -550,6 +568,21 @@ date!(Date64Type);
trait IntervalOp: ArrowPrimitiveType {
     fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
     fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
+    fn mul_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError>;
+    fn mul_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError>;
+    fn div_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError>;
+    fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError>;
Review Comment:
Instead of instantiating a single-value `interval [/*] number` operation to
build all the array operations we want, we can approach this in a way that
avoids the combinatorial explosion.
To implement `interval_array [/*] number` operations we need: operations that
convert intervals to integers counting the smallest unit of the interval type,
the regular int and float multiplication/division kernels (which already
exist), and the conversion back to interval types. Each conversion is
parameterized by a single type, so the number of specializations isn't the
product of the number of interval types and the number of int and float types.
All of these operations work at the array level, not at the single-value
level.
- `IntervalYearMonthType` is already an `int32` array (number of months)
[1], so just fall back to the `int32` x ... kernels and convert the result to
an `int32` number of months again (the conversion might not even be needed,
depending on what `rhs` is)
- `IntervalDayTimeType` is days and milliseconds (both `int32`), so if the
whole array is converted to an `int64` array of milliseconds you can delegate
to the regular `*` and `/` kernels and convert the result back
- `IntervalMonthDayNanoType`: similar idea.
[1]
https://github.com/apache/arrow/blob/02a165922e46e5ed6dd3ed2446141cd0922a7c54/format/Schema.fbs#L398
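
As a hedged illustration of the `IntervalYearMonthType` case, the sketch below
treats an interval array as a plain slice of `i32` month counts and reuses
ordinary integer and float arithmetic. The function names, the `String`
errors, and the round-to-nearest policy for float results are assumptions of
this sketch, not arrow-rs behavior.

```rust
/// Year-month intervals modeled as i32 month counts; multiply by an integer.
fn months_mul_int(months: &[i32], rhs: i32) -> Result<Vec<i32>, String> {
    months
        .iter()
        .map(|&m| -> Result<i32, String> {
            m.checked_mul(rhs)
                .ok_or_else(|| format!("overflow: {} * {}", m, rhs))
        })
        .collect()
}

/// Multiply by a float: widen to f64, multiply, then convert back to months.
fn months_mul_float(months: &[i32], rhs: f64) -> Result<Vec<i32>, String> {
    months
        .iter()
        .map(|&m| -> Result<i32, String> {
            // Assumption: round to the nearest whole month.
            let product = (m as f64 * rhs).round();
            if product.is_finite() && product >= i32::MIN as f64 && product <= i32::MAX as f64 {
                Ok(product as i32)
            } else {
                Err(format!("result out of range: {} * {}", m, rhs))
            }
        })
        .collect()
}
```

Here the integer path needs no conversion at all, which matches the remark
that the conversion may be unnecessary depending on what `rhs` is.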