zeroshade commented on code in PR #572:
URL: https://github.com/apache/arrow-go/pull/572#discussion_r2528905146
##########
arrow/compute/internal/kernels/rounding.go:
##########
@@ -807,3 +808,681 @@ func FixedRoundDecimalExec[T decimal128.Num |
decimal256.Num](mode RoundMode) ex
}
panic("should never get here")
}
+
+// RoundTemporalUnit represents units supported for temporal rounding
+type RoundTemporalUnit int8
+
+const (
+ RoundTemporalYear RoundTemporalUnit = iota
+ RoundTemporalQuarter
+ RoundTemporalMonth
+ RoundTemporalWeek
+ RoundTemporalDay
+ RoundTemporalHour
+ RoundTemporalMinute
+ RoundTemporalSecond
+ RoundTemporalMillisecond
+ RoundTemporalMicrosecond
+ RoundTemporalNanosecond
+)
+
+// RoundTemporalOptions provides configuration for temporal rounding operations
+type RoundTemporalOptions struct {
+ // Multiple is the number of units to round to. Must be positive.
+ Multiple int64
+ // Unit is the rounding unit (day, hour, etc.)
+ Unit RoundTemporalUnit
+ // WeekStartsMonday determines the start of the week for week-based
rounding
+ WeekStartsMonday bool
+ // CeilIsStrictlyGreater: if true, ceil returns a value strictly
greater than input
+ CeilIsStrictlyGreater bool
+ // CalendarBasedOrigin: if true, use calendar units as origin (e.g.,
start of day for hours)
+ CalendarBasedOrigin bool
+}
+
+func (RoundTemporalOptions) TypeName() string { return "RoundTemporalOptions" }
+
+type roundTemporalState struct {
+ RoundTemporalOptions
+ mode RoundMode
+
+ // Pre-calculated values to avoid repeated computation
+ unitNanos int64 // Duration of the unit in nanoseconds
+ roundingInterval int64 // unitNanos * Multiple
+ isSubDay bool // true if this is a sub-day unit (can use fast
path)
+ useCalendarOrigin bool // true if using calendar-based origin
+}
+
+func InitRoundTemporalState(_ *exec.KernelCtx, args exec.KernelInitArgs)
(exec.KernelState, error) {
+ var rs roundTemporalState
+
+ opts, ok := args.Options.(*RoundTemporalOptions)
+ if ok {
+ rs.RoundTemporalOptions = *opts
+ } else {
+ if rs.RoundTemporalOptions, ok =
args.Options.(RoundTemporalOptions); !ok {
+ return nil, fmt.Errorf("%w: attempted to initialize
kernel state from invalid function options",
+ arrow.ErrInvalid)
+ }
+ }
+
+ if rs.Multiple <= 0 {
+ return nil, fmt.Errorf("%w: rounding multiple must be
positive", arrow.ErrInvalid)
+ }
+
+ // Pre-calculate constants for this rounding operation
+ rs.unitNanos, rs.isSubDay = unitInNanos(rs.Unit)
+ if rs.isSubDay {
+ rs.roundingInterval = rs.unitNanos * rs.Multiple
+ rs.useCalendarOrigin = rs.CalendarBasedOrigin && rs.Unit <=
RoundTemporalDay
+ }
+
+ return rs, nil
+}
+
+// unitInNanos returns (nanoseconds, hasFixedDuration) for a temporal unit.
+// Returns false for calendar units with variable durations (year, quarter,
month, week).
+func unitInNanos(unit RoundTemporalUnit) (int64, bool) {
+ switch unit {
+ case RoundTemporalNanosecond:
+ return 1, true
+ case RoundTemporalMicrosecond:
+ return 1000, true
+ case RoundTemporalMillisecond:
+ return 1000000, true
+ case RoundTemporalSecond:
+ return 1000000000, true
+ case RoundTemporalMinute:
+ return 60 * 1000000000, true
+ case RoundTemporalHour:
+ return 3600 * 1000000000, true
+ case RoundTemporalDay:
+ return 86400 * 1000000000, true
+ default:
+ return 0, false
+ }
+}
+
+// roundTimestamp rounds a timestamp value according to the specified options.
+// tz specifies the timezone for calendar-aware rounding (nil defaults to UTC).
+func roundTimestamp(ts int64, inputUnit arrow.TimeUnit, tz *time.Location,
opts roundTemporalState) (int64, error) {
+ if tz == nil {
+ tz = time.UTC
+ }
+
+ // Calendar units with variable duration (year, quarter, month, week)
require date arithmetic
+ if !opts.isSubDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Day rounding with timezone requires calendar arithmetic (days vary:
23/24/25 hours due to DST)
+ isUTC := tz == time.UTC || tz.String() == "UTC"
+ if !isUTC && opts.Unit == RoundTemporalDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Sub-day units (hour, minute, second, etc.) use fixed-duration
arithmetic
+ // Fast path: round directly in input unit if possible (no origin,
compatible units)
+ if canRoundInInputUnit(inputUnit, opts.unitNanos) &&
!opts.useCalendarOrigin {
+ intervalInInputUnit := opts.roundingInterval /
unitScaleFactor(inputUnit)
+ rounded := roundToMultipleInt64(ts, intervalInInputUnit,
opts.mode, opts.CeilIsStrictlyGreater)
+ return rounded, nil
+ }
+
+ // Slow path: convert to nanoseconds for calendar origin or
incompatible units
+ tsNanos := convertToNanos(ts, inputUnit)
+
+ var origin int64 = 0
+ if opts.useCalendarOrigin {
+ // Calendar origin: round relative to start of day
(timezone-aware if tz != nil)
+ if tz != nil {
+ t := time.Unix(0, tsNanos).In(tz)
+ startOfDay := time.Date(t.Year(), t.Month(), t.Day(),
0, 0, 0, 0, tz)
+ origin = startOfDay.UnixNano()
+ } else {
+ origin = tsNanos
+ }
+ }
+
+ adjusted := tsNanos - origin
+ rounded := roundToMultipleInt64(adjusted, opts.roundingInterval,
opts.mode, opts.CeilIsStrictlyGreater)
+ result := origin + rounded
+
+ return convertFromNanos(result, inputUnit), nil
+}
+
+// unitScaleFactor returns nanoseconds per unit for the given time unit
+func unitScaleFactor(unit arrow.TimeUnit) int64 {
+ switch unit {
+ case arrow.Second:
+ return 1_000_000_000
+ case arrow.Millisecond:
+ return 1_000_000
+ case arrow.Microsecond:
+ return 1_000
+ case arrow.Nanosecond:
+ return 1
+ default:
+ return 1
+ }
+}
+
+// canRoundInInputUnit checks if rounding can be done in the input unit
+// without converting to nanoseconds (true when rounding interval is evenly
divisible).
+func canRoundInInputUnit(inputUnit arrow.TimeUnit, roundingIntervalNanos
int64) bool {
+ return roundingIntervalNanos%unitScaleFactor(inputUnit) == 0
+}
+
+// convertToNanos converts a timestamp value to nanoseconds
+func convertToNanos(ts int64, unit arrow.TimeUnit) int64 {
+ return ts * unitScaleFactor(unit)
+}
+
+// convertFromNanos converts a nanosecond timestamp to the specified unit
+func convertFromNanos(nanos int64, unit arrow.TimeUnit) int64 {
+ return nanos / unitScaleFactor(unit)
Review Comment:
```suggestion
return nanos / int64(unit.Multiplier())
```
##########
arrow/compute/internal/kernels/rounding.go:
##########
@@ -807,3 +808,681 @@ func FixedRoundDecimalExec[T decimal128.Num |
decimal256.Num](mode RoundMode) ex
}
panic("should never get here")
}
+
+// RoundTemporalUnit represents units supported for temporal rounding
+type RoundTemporalUnit int8
+
+const (
+ RoundTemporalYear RoundTemporalUnit = iota
+ RoundTemporalQuarter
+ RoundTemporalMonth
+ RoundTemporalWeek
+ RoundTemporalDay
+ RoundTemporalHour
+ RoundTemporalMinute
+ RoundTemporalSecond
+ RoundTemporalMillisecond
+ RoundTemporalMicrosecond
+ RoundTemporalNanosecond
+)
+
+// RoundTemporalOptions provides configuration for temporal rounding operations
+type RoundTemporalOptions struct {
+ // Multiple is the number of units to round to. Must be positive.
+ Multiple int64
+ // Unit is the rounding unit (day, hour, etc.)
+ Unit RoundTemporalUnit
+ // WeekStartsMonday determines the start of the week for week-based
rounding
+ WeekStartsMonday bool
+ // CeilIsStrictlyGreater: if true, ceil returns a value strictly
greater than input
+ CeilIsStrictlyGreater bool
+ // CalendarBasedOrigin: if true, use calendar units as origin (e.g.,
start of day for hours)
+ CalendarBasedOrigin bool
+}
+
+func (RoundTemporalOptions) TypeName() string { return "RoundTemporalOptions" }
+
+type roundTemporalState struct {
+ RoundTemporalOptions
+ mode RoundMode
+
+ // Pre-calculated values to avoid repeated computation
+ unitNanos int64 // Duration of the unit in nanoseconds
+ roundingInterval int64 // unitNanos * Multiple
+ isSubDay bool // true if this is a sub-day unit (can use fast
path)
+ useCalendarOrigin bool // true if using calendar-based origin
+}
+
+func InitRoundTemporalState(_ *exec.KernelCtx, args exec.KernelInitArgs)
(exec.KernelState, error) {
+ var rs roundTemporalState
+
+ opts, ok := args.Options.(*RoundTemporalOptions)
+ if ok {
+ rs.RoundTemporalOptions = *opts
+ } else {
+ if rs.RoundTemporalOptions, ok =
args.Options.(RoundTemporalOptions); !ok {
+ return nil, fmt.Errorf("%w: attempted to initialize
kernel state from invalid function options",
+ arrow.ErrInvalid)
+ }
+ }
+
+ if rs.Multiple <= 0 {
+ return nil, fmt.Errorf("%w: rounding multiple must be
positive", arrow.ErrInvalid)
+ }
+
+ // Pre-calculate constants for this rounding operation
+ rs.unitNanos, rs.isSubDay = unitInNanos(rs.Unit)
+ if rs.isSubDay {
+ rs.roundingInterval = rs.unitNanos * rs.Multiple
+ rs.useCalendarOrigin = rs.CalendarBasedOrigin && rs.Unit <=
RoundTemporalDay
+ }
+
+ return rs, nil
+}
+
+// unitInNanos returns (nanoseconds, hasFixedDuration) for a temporal unit.
+// Returns false for calendar units with variable durations (year, quarter,
month, week).
+func unitInNanos(unit RoundTemporalUnit) (int64, bool) {
+ switch unit {
+ case RoundTemporalNanosecond:
+ return 1, true
+ case RoundTemporalMicrosecond:
+ return 1000, true
+ case RoundTemporalMillisecond:
+ return 1000000, true
+ case RoundTemporalSecond:
+ return 1000000000, true
+ case RoundTemporalMinute:
+ return 60 * 1000000000, true
+ case RoundTemporalHour:
+ return 3600 * 1000000000, true
+ case RoundTemporalDay:
+ return 86400 * 1000000000, true
+ default:
+ return 0, false
+ }
+}
+
+// roundTimestamp rounds a timestamp value according to the specified options.
+// tz specifies the timezone for calendar-aware rounding (nil defaults to UTC).
+func roundTimestamp(ts int64, inputUnit arrow.TimeUnit, tz *time.Location,
opts roundTemporalState) (int64, error) {
+ if tz == nil {
+ tz = time.UTC
+ }
+
+ // Calendar units with variable duration (year, quarter, month, week)
require date arithmetic
+ if !opts.isSubDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Day rounding with timezone requires calendar arithmetic (days vary:
23/24/25 hours due to DST)
+ isUTC := tz == time.UTC || tz.String() == "UTC"
+ if !isUTC && opts.Unit == RoundTemporalDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Sub-day units (hour, minute, second, etc.) use fixed-duration
arithmetic
+ // Fast path: round directly in input unit if possible (no origin,
compatible units)
+ if canRoundInInputUnit(inputUnit, opts.unitNanos) &&
!opts.useCalendarOrigin {
+ intervalInInputUnit := opts.roundingInterval /
unitScaleFactor(inputUnit)
+ rounded := roundToMultipleInt64(ts, intervalInInputUnit,
opts.mode, opts.CeilIsStrictlyGreater)
+ return rounded, nil
+ }
+
+ // Slow path: convert to nanoseconds for calendar origin or
incompatible units
+ tsNanos := convertToNanos(ts, inputUnit)
+
+ var origin int64 = 0
+ if opts.useCalendarOrigin {
+ // Calendar origin: round relative to start of day
(timezone-aware if tz != nil)
+ if tz != nil {
+ t := time.Unix(0, tsNanos).In(tz)
+ startOfDay := time.Date(t.Year(), t.Month(), t.Day(),
0, 0, 0, 0, tz)
+ origin = startOfDay.UnixNano()
+ } else {
+ origin = tsNanos
+ }
+ }
+
+ adjusted := tsNanos - origin
+ rounded := roundToMultipleInt64(adjusted, opts.roundingInterval,
opts.mode, opts.CeilIsStrictlyGreater)
+ result := origin + rounded
+
+ return convertFromNanos(result, inputUnit), nil
+}
+
+// unitScaleFactor returns nanoseconds per unit for the given time unit
+func unitScaleFactor(unit arrow.TimeUnit) int64 {
+ switch unit {
+ case arrow.Second:
+ return 1_000_000_000
+ case arrow.Millisecond:
+ return 1_000_000
+ case arrow.Microsecond:
+ return 1_000
+ case arrow.Nanosecond:
+ return 1
Review Comment:
`int64(unit.Multiplier())` already returns the number of nanoseconds per unit, so there's no need for `unitScaleFactor`.
##########
arrow/compute/internal/kernels/rounding.go:
##########
@@ -807,3 +808,681 @@ func FixedRoundDecimalExec[T decimal128.Num |
decimal256.Num](mode RoundMode) ex
}
panic("should never get here")
}
+
+// RoundTemporalUnit represents units supported for temporal rounding
+type RoundTemporalUnit int8
+
+const (
+ RoundTemporalYear RoundTemporalUnit = iota
+ RoundTemporalQuarter
+ RoundTemporalMonth
+ RoundTemporalWeek
+ RoundTemporalDay
+ RoundTemporalHour
+ RoundTemporalMinute
+ RoundTemporalSecond
+ RoundTemporalMillisecond
+ RoundTemporalMicrosecond
+ RoundTemporalNanosecond
+)
+
+// RoundTemporalOptions provides configuration for temporal rounding operations
+type RoundTemporalOptions struct {
+ // Multiple is the number of units to round to. Must be positive.
+ Multiple int64
+ // Unit is the rounding unit (day, hour, etc.)
+ Unit RoundTemporalUnit
+ // WeekStartsMonday determines the start of the week for week-based
rounding
+ WeekStartsMonday bool
+ // CeilIsStrictlyGreater: if true, ceil returns a value strictly
greater than input
+ CeilIsStrictlyGreater bool
+ // CalendarBasedOrigin: if true, use calendar units as origin (e.g.,
start of day for hours)
+ CalendarBasedOrigin bool
+}
+
+func (RoundTemporalOptions) TypeName() string { return "RoundTemporalOptions" }
+
+type roundTemporalState struct {
+ RoundTemporalOptions
+ mode RoundMode
+
+ // Pre-calculated values to avoid repeated computation
+ unitNanos int64 // Duration of the unit in nanoseconds
+ roundingInterval int64 // unitNanos * Multiple
+ isSubDay bool // true if this is a sub-day unit (can use fast
path)
+ useCalendarOrigin bool // true if using calendar-based origin
+}
+
+func InitRoundTemporalState(_ *exec.KernelCtx, args exec.KernelInitArgs)
(exec.KernelState, error) {
+ var rs roundTemporalState
+
+ opts, ok := args.Options.(*RoundTemporalOptions)
+ if ok {
+ rs.RoundTemporalOptions = *opts
+ } else {
+ if rs.RoundTemporalOptions, ok =
args.Options.(RoundTemporalOptions); !ok {
+ return nil, fmt.Errorf("%w: attempted to initialize
kernel state from invalid function options",
+ arrow.ErrInvalid)
+ }
+ }
+
+ if rs.Multiple <= 0 {
+ return nil, fmt.Errorf("%w: rounding multiple must be
positive", arrow.ErrInvalid)
+ }
+
+ // Pre-calculate constants for this rounding operation
+ rs.unitNanos, rs.isSubDay = unitInNanos(rs.Unit)
+ if rs.isSubDay {
+ rs.roundingInterval = rs.unitNanos * rs.Multiple
+ rs.useCalendarOrigin = rs.CalendarBasedOrigin && rs.Unit <=
RoundTemporalDay
+ }
+
+ return rs, nil
+}
+
+// unitInNanos returns (nanoseconds, hasFixedDuration) for a temporal unit.
+// Returns false for calendar units with variable durations (year, quarter,
month, week).
+func unitInNanos(unit RoundTemporalUnit) (int64, bool) {
+ switch unit {
+ case RoundTemporalNanosecond:
+ return 1, true
+ case RoundTemporalMicrosecond:
+ return 1000, true
+ case RoundTemporalMillisecond:
+ return 1000000, true
+ case RoundTemporalSecond:
+ return 1000000000, true
+ case RoundTemporalMinute:
+ return 60 * 1000000000, true
+ case RoundTemporalHour:
+ return 3600 * 1000000000, true
+ case RoundTemporalDay:
+ return 86400 * 1000000000, true
+ default:
+ return 0, false
+ }
+}
+
+// roundTimestamp rounds a timestamp value according to the specified options.
+// tz specifies the timezone for calendar-aware rounding (nil defaults to UTC).
+func roundTimestamp(ts int64, inputUnit arrow.TimeUnit, tz *time.Location,
opts roundTemporalState) (int64, error) {
+ if tz == nil {
+ tz = time.UTC
+ }
+
+ // Calendar units with variable duration (year, quarter, month, week)
require date arithmetic
+ if !opts.isSubDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Day rounding with timezone requires calendar arithmetic (days vary:
23/24/25 hours due to DST)
+ isUTC := tz == time.UTC || tz.String() == "UTC"
+ if !isUTC && opts.Unit == RoundTemporalDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Sub-day units (hour, minute, second, etc.) use fixed-duration
arithmetic
+ // Fast path: round directly in input unit if possible (no origin,
compatible units)
+ if canRoundInInputUnit(inputUnit, opts.unitNanos) &&
!opts.useCalendarOrigin {
+ intervalInInputUnit := opts.roundingInterval /
unitScaleFactor(inputUnit)
+ rounded := roundToMultipleInt64(ts, intervalInInputUnit,
opts.mode, opts.CeilIsStrictlyGreater)
+ return rounded, nil
+ }
+
+ // Slow path: convert to nanoseconds for calendar origin or
incompatible units
+ tsNanos := convertToNanos(ts, inputUnit)
+
+ var origin int64 = 0
+ if opts.useCalendarOrigin {
+ // Calendar origin: round relative to start of day
(timezone-aware if tz != nil)
+ if tz != nil {
+ t := time.Unix(0, tsNanos).In(tz)
+ startOfDay := time.Date(t.Year(), t.Month(), t.Day(),
0, 0, 0, 0, tz)
+ origin = startOfDay.UnixNano()
+ } else {
+ origin = tsNanos
+ }
+ }
+
+ adjusted := tsNanos - origin
+ rounded := roundToMultipleInt64(adjusted, opts.roundingInterval,
opts.mode, opts.CeilIsStrictlyGreater)
+ result := origin + rounded
+
+ return convertFromNanos(result, inputUnit), nil
+}
+
+// unitScaleFactor returns nanoseconds per unit for the given time unit
+func unitScaleFactor(unit arrow.TimeUnit) int64 {
+ switch unit {
+ case arrow.Second:
+ return 1_000_000_000
+ case arrow.Millisecond:
+ return 1_000_000
+ case arrow.Microsecond:
+ return 1_000
+ case arrow.Nanosecond:
+ return 1
+ default:
+ return 1
+ }
+}
+
+// canRoundInInputUnit checks if rounding can be done in the input unit
+// without converting to nanoseconds (true when rounding interval is evenly
divisible).
+func canRoundInInputUnit(inputUnit arrow.TimeUnit, roundingIntervalNanos
int64) bool {
+ return roundingIntervalNanos%unitScaleFactor(inputUnit) == 0
+}
+
+// convertToNanos converts a timestamp value to nanoseconds
+func convertToNanos(ts int64, unit arrow.TimeUnit) int64 {
+ return ts * unitScaleFactor(unit)
+}
+
+// convertFromNanos converts a nanosecond timestamp to the specified unit
+func convertFromNanos(nanos int64, unit arrow.TimeUnit) int64 {
+ return nanos / unitScaleFactor(unit)
+}
+
+func roundToMultipleInt64(value, multiple int64, mode RoundMode, strictCeil
bool) int64 {
+ if multiple == 0 || value%multiple == 0 {
+ if strictCeil && mode == RoundUp {
+ return value + multiple
+ }
+ return value
+ }
+
+ quotient := value / multiple
+ remainder := value % multiple
+
+ switch mode {
+ case RoundDown:
+ if remainder < 0 {
+ return (quotient - 1) * multiple
+ }
+ return quotient * multiple
+ case RoundUp:
+ if remainder > 0 || (strictCeil && remainder == 0) {
+ return (quotient + 1) * multiple
+ }
+ if remainder < 0 {
+ return quotient * multiple
+ }
+ return (quotient + 1) * multiple
+ case HalfUp, HalfDown, HalfToEven:
+ half := multiple / 2
+ absRemainder := remainder
+ if absRemainder < 0 {
+ absRemainder = -absRemainder
+ }
+
+ if absRemainder < half {
+ return quotient * multiple
+ } else if absRemainder > half {
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return (quotient - 1) * multiple
+ } else {
+ // Exactly on the halfway point
+ switch mode {
+ case HalfDown:
+ if remainder > 0 {
+ return quotient * multiple
+ }
+ return (quotient - 1) * multiple
+ case HalfUp:
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return quotient * multiple
+ case HalfToEven:
+ if quotient%2 == 0 {
+ return quotient * multiple
+ }
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return (quotient - 1) * multiple
+ }
+ }
+ }
+ return quotient * multiple
+}
+
+// halfRoundPeriod performs half-rounding by finding the midpoint between
period start and end
+func halfRoundPeriod(t, periodStart, periodEnd time.Time) time.Time {
+ midPoint := periodStart.Add(periodEnd.Sub(periodStart) / 2)
+ if t.Before(midPoint) {
+ return periodStart
+ }
+ return periodEnd
+}
+
+// roundTimestampCalendar handles calendar-based rounding (year, quarter,
month, week, day).
+// Requires date arithmetic for variable-length periods and timezone-aware
boundaries.
+func roundTimestampCalendar(tsNanos int64, inputUnit arrow.TimeUnit, tz
*time.Location, opts roundTemporalState) (int64, error) {
+ // Convert to time.Time for calendar operations in the specified
timezone
+ secs := tsNanos / 1000000000
+ nanos := tsNanos % 1000000000
+ t := time.Unix(secs, nanos).In(tz)
+
+ var rounded time.Time
+
+ switch opts.Unit {
+ case RoundTemporalYear:
+ year := t.Year()
+ roundedYear := (year / int(opts.Multiple)) * int(opts.Multiple)
+ switch opts.mode {
+ case RoundDown:
+ rounded = time.Date(roundedYear, 1, 1, 0, 0, 0, 0, tz)
+ case RoundUp:
+ periodStart := time.Date(roundedYear, 1, 1, 0, 0, 0, 0,
tz)
+ if opts.CeilIsStrictlyGreater || !t.Equal(periodStart) {
+ roundedYear += int(opts.Multiple)
+ rounded = time.Date(roundedYear, 1, 1, 0, 0, 0,
0, tz)
+ } else {
+ rounded = periodStart
+ }
+ default:
+ yearStart := time.Date(roundedYear, 1, 1, 0, 0, 0, 0,
tz)
+ nextYear := roundedYear + int(opts.Multiple)
+ yearEnd := time.Date(nextYear, 1, 1, 0, 0, 0, 0, tz)
+ rounded = halfRoundPeriod(t, yearStart, yearEnd)
+ }
+
+ case RoundTemporalQuarter:
+ // Q1=Jan-Mar, Q2=Apr-Jun, Q3=Jul-Sep, Q4=Oct-Dec
+ month := int(t.Month())
+ year := t.Year()
+ totalQuarters := year*4 + (month-1)/3
+ roundedQuarters := (totalQuarters / int(opts.Multiple)) *
int(opts.Multiple)
+ roundedYear := roundedQuarters / 4
+ roundedQuarter := roundedQuarters % 4
+ roundedMonth := roundedQuarter*3 + 1 // First month of the
quarter
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ case RoundUp:
+ periodStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ if opts.CeilIsStrictlyGreater || !t.Equal(periodStart) {
+ roundedQuarters += int(opts.Multiple)
+ roundedYear = roundedQuarters / 4
+ roundedQuarter = roundedQuarters % 4
+ roundedMonth = roundedQuarter*3 + 1
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ } else {
+ rounded = periodStart
+ }
+ default:
+ quarterStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ nextQuarterNum := roundedQuarters + int(opts.Multiple)
+ nextYear := nextQuarterNum / 4
+ nextQuarter := nextQuarterNum % 4
+ nextMonth := nextQuarter*3 + 1
+ quarterEnd := time.Date(nextYear,
time.Month(nextMonth), 1, 0, 0, 0, 0, tz)
+ rounded = halfRoundPeriod(t, quarterStart, quarterEnd)
+ }
+
+ case RoundTemporalMonth:
+ month := int(t.Month())
+ year := t.Year()
+ totalMonths := year*12 + month - 1
+ roundedMonths := (totalMonths / int(opts.Multiple)) *
int(opts.Multiple)
+ roundedYear := roundedMonths / 12
+ roundedMonth := (roundedMonths % 12) + 1
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ case RoundUp:
+ periodStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ if opts.CeilIsStrictlyGreater || !t.Equal(periodStart) {
+ roundedMonths += int(opts.Multiple)
+ roundedYear = roundedMonths / 12
+ roundedMonth = (roundedMonths % 12) + 1
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ } else {
+ rounded = periodStart
+ }
+ default:
+ monthStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ nextMonthNum := roundedMonths + int(opts.Multiple)
+ nextYear := nextMonthNum / 12
+ nextMonth := (nextMonthNum % 12) + 1
+ monthEnd := time.Date(nextYear, time.Month(nextMonth),
1, 0, 0, 0, 0, tz)
+ rounded = halfRoundPeriod(t, monthStart, monthEnd)
+ }
+
+ case RoundTemporalWeek:
+ weekday := int(t.Weekday())
+ if opts.WeekStartsMonday {
+ weekday = (weekday + 6) % 7
+ }
+ startOfWeek := t.AddDate(0, 0, -weekday)
+ startOfWeek = time.Date(startOfWeek.Year(),
startOfWeek.Month(), startOfWeek.Day(), 0, 0, 0, 0, tz)
+
+ // Calculate N-week periods from epoch for Multiple > 1
+ epochInTz := time.Unix(0, 0).In(tz)
+ epochWeekday := int(epochInTz.Weekday())
+ if opts.WeekStartsMonday {
+ epochWeekday = (epochWeekday + 6) % 7
+ }
+ epochWeekStart := epochInTz.AddDate(0, 0, -epochWeekday)
+ epochWeekStart = time.Date(epochWeekStart.Year(),
epochWeekStart.Month(), epochWeekStart.Day(), 0, 0, 0, 0, tz)
+
+ daysSinceEpochWeek :=
int(startOfWeek.Sub(epochWeekStart).Hours() / 24)
+ weeksSinceEpoch := daysSinceEpochWeek / 7
+ roundedWeeks := (weeksSinceEpoch / int(opts.Multiple)) *
int(opts.Multiple)
+ roundedWeekStart := epochWeekStart.AddDate(0, 0, roundedWeeks*7)
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = roundedWeekStart
+ case RoundUp:
+ if opts.CeilIsStrictlyGreater ||
!t.Equal(roundedWeekStart) {
+ rounded = roundedWeekStart.AddDate(0, 0,
7*int(opts.Multiple))
+ } else {
+ rounded = roundedWeekStart
+ }
+ default:
+ weekEnd := roundedWeekStart.AddDate(0, 0,
7*int(opts.Multiple))
+ rounded = halfRoundPeriod(t, roundedWeekStart, weekEnd)
+ }
+
+ case RoundTemporalDay:
+ startOfDay := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0,
0, tz)
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = startOfDay
+ case RoundUp:
+ if opts.CeilIsStrictlyGreater || !t.Equal(startOfDay) {
+ rounded = startOfDay.AddDate(0, 0, 1)
+ } else {
+ rounded = startOfDay
+ }
+ default:
+ nextDay := startOfDay.AddDate(0, 0, 1)
+ rounded = halfRoundPeriod(t, startOfDay, nextDay)
+ }
+
+ default:
+ return 0, fmt.Errorf("%w: unsupported calendar unit",
arrow.ErrNotImplemented)
+ }
+
+ // Convert back to the input unit
+ roundedNanos := rounded.UnixNano()
+ return convertFromNanos(roundedNanos, inputUnit), nil
+}
+
+// Kernel execution functions for temporal rounding
+func FloorTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = RoundDown
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func CeilTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = RoundUp
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func RoundTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = HalfUp
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func roundTemporalExec(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult, state roundTemporalState) error {
+ input := &batch.Values[0].Array
+
+ // Handle date types by converting to timestamp equivalents
+ switch input.Type.ID() {
+ case arrow.DATE32:
+ // Date32 stores days since epoch as int32, treat as
timestamp[s] at midnight
+ fn := func(_ *exec.KernelCtx, days int32, e *error) int32 {
+ // Convert days to seconds (timestamp at midnight UTC)
+ tsSeconds := int64(days) * 86400
+ result, err := roundTimestamp(tsSeconds, arrow.Second,
nil, state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ // Convert back to days
+ return int32(result / 86400)
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+
+ case arrow.DATE64:
+ // Date64 stores milliseconds since epoch, treat as
timestamp[ms]
+ fn := func(_ *exec.KernelCtx, ms int64, e *error) int64 {
+ result, err := roundTimestamp(ms, arrow.Millisecond,
nil, state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ return result
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+
+ case arrow.TIME32:
+ // Time32 stores time-of-day in seconds or milliseconds
+ // Rounding wraps at day boundaries (modulo 24 hours)
+ timeType := input.Type.(*arrow.Time32Type)
+ fn := func(_ *exec.KernelCtx, time int32, e *error) int32 {
+ // Convert to int64 for rounding
+ result, err := roundTimestamp(int64(time),
timeType.Unit, nil, state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ // Wrap at day boundary
+ var dayInUnit int64
+ if timeType.Unit == arrow.Second {
+ dayInUnit = 86400 // 24 hours in seconds
+ } else {
+ dayInUnit = 86400000 // 24 hours in milliseconds
+ }
+ wrapped := result % dayInUnit
+ if wrapped < 0 {
+ wrapped += dayInUnit
+ }
+ return int32(wrapped)
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+
+ case arrow.TIME64:
+ // Time64 stores time-of-day in microseconds or nanoseconds
+ // Rounding wraps at day boundaries (modulo 24 hours)
+ timeType := input.Type.(*arrow.Time64Type)
+ fn := func(_ *exec.KernelCtx, time int64, e *error) int64 {
+ result, err := roundTimestamp(time, timeType.Unit, nil,
state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ // Wrap at day boundary
+ var dayInUnit int64
+ if timeType.Unit == arrow.Microsecond {
+ dayInUnit = 86400000000 // 24 hours in
microseconds
+ } else {
+ dayInUnit = 86400000000000 // 24 hours in
nanoseconds
+ }
+ wrapped := result % dayInUnit
+ if wrapped < 0 {
+ wrapped += dayInUnit
+ }
+ return wrapped
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+ }
+
+ // Handle timestamp types
+ inputType := input.Type.(arrow.TemporalWithUnit)
+
+ // Extract timezone if present (for timestamp types)
+ var tz *time.Location
+ if tsType, ok := input.Type.(*arrow.TimestampType); ok &&
tsType.TimeZone != "" {
+ var err error
+ tz, err = time.LoadLocation(tsType.TimeZone)
+ if err != nil {
+ return fmt.Errorf("%w: invalid timezone %q: %v",
arrow.ErrInvalid, tsType.TimeZone, err)
+ }
+ }
+
+ fn := func(_ *exec.KernelCtx, ts int64, e *error) int64 {
+ result, err := roundTimestamp(ts, inputType.TimeUnit(), tz,
state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ return result
+ }
+
+ switch inputType.TimeUnit() {
+ case arrow.Second, arrow.Millisecond, arrow.Microsecond,
arrow.Nanosecond:
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+ default:
+ return fmt.Errorf("%w: unsupported time unit",
arrow.ErrNotImplemented)
+ }
+}
+
+type timestampUnitMatcher struct {
+ unit arrow.TimeUnit
+}
+
+func (m *timestampUnitMatcher) Matches(typ arrow.DataType) bool {
+ if ts, ok := typ.(*arrow.TimestampType); ok {
+ return ts.Unit == m.unit
+ }
+ return false
+}
+
+func (m *timestampUnitMatcher) String() string {
+ return "timestamp(unit=" + m.unit.String() + ")"
+}
+
+func (m *timestampUnitMatcher) Equals(other exec.TypeMatcher) bool {
+ if o, ok := other.(*timestampUnitMatcher); ok {
+ return m.unit == o.unit
+ }
+ return false
+}
Review Comment:
We already have a `TimestampTypeUnit` matcher in `exec/kernel.go` that you
can use instead of this custom `timestampUnitMatcher`.
##########
arrow/compute/internal/kernels/rounding.go:
##########
@@ -807,3 +808,681 @@ func FixedRoundDecimalExec[T decimal128.Num |
decimal256.Num](mode RoundMode) ex
}
panic("should never get here")
}
+
// RoundTemporalUnit selects the calendar or clock unit that temporal rounding
// snaps values to. The constants are declared from the coarsest unit (year)
// to the finest (nanosecond), and their numeric order is significant: code
// compares units against RoundTemporalDay.
type RoundTemporalUnit int8

const (
	RoundTemporalYear        RoundTemporalUnit = iota // calendar year
	RoundTemporalQuarter                              // three calendar months
	RoundTemporalMonth                                // calendar month
	RoundTemporalWeek                                 // seven days
	RoundTemporalDay                                  // one day
	RoundTemporalHour                                 // one hour
	RoundTemporalMinute                               // one minute
	RoundTemporalSecond                               // one second
	RoundTemporalMillisecond                          // one millisecond
	RoundTemporalMicrosecond                          // one microsecond
	RoundTemporalNanosecond                           // one nanosecond
)
+
+// RoundTemporalOptions provides configuration for temporal rounding operations
+type RoundTemporalOptions struct {
+ // Multiple is the number of units to round to. Must be positive.
+ Multiple int64
+ // Unit is the rounding unit (day, hour, etc.)
+ Unit RoundTemporalUnit
+ // WeekStartsMonday determines the start of the week for week-based
rounding
+ WeekStartsMonday bool
+ // CeilIsStrictlyGreater: if true, ceil returns a value strictly
greater than input
+ CeilIsStrictlyGreater bool
+ // CalendarBasedOrigin: if true, use calendar units as origin (e.g.,
start of day for hours)
+ CalendarBasedOrigin bool
+}
+
+func (RoundTemporalOptions) TypeName() string { return "RoundTemporalOptions" }
+
+type roundTemporalState struct {
+ RoundTemporalOptions
+ mode RoundMode
+
+ // Pre-calculated values to avoid repeated computation
+ unitNanos int64 // Duration of the unit in nanoseconds
+ roundingInterval int64 // unitNanos * Multiple
+ isSubDay bool // true if this is a sub-day unit (can use fast
path)
+ useCalendarOrigin bool // true if using calendar-based origin
+}
+
// InitRoundTemporalState validates the RoundTemporalOptions handed to a
// temporal rounding kernel and precomputes the values cached in
// roundTemporalState.
//
// The options may be passed by value or by (non-nil) pointer. An error
// wrapping arrow.ErrInvalid is returned when the options have the wrong type,
// when Multiple is not positive, or when Multiple units do not fit in an int64
// count of nanoseconds.
func InitRoundTemporalState(_ *exec.KernelCtx, args exec.KernelInitArgs) (exec.KernelState, error) {
	var rs roundTemporalState

	switch opts := args.Options.(type) {
	case *RoundTemporalOptions:
		if opts == nil {
			return nil, fmt.Errorf("%w: attempted to initialize kernel state from invalid function options",
				arrow.ErrInvalid)
		}
		rs.RoundTemporalOptions = *opts
	case RoundTemporalOptions:
		rs.RoundTemporalOptions = opts
	default:
		return nil, fmt.Errorf("%w: attempted to initialize kernel state from invalid function options",
			arrow.ErrInvalid)
	}

	if rs.Multiple <= 0 {
		return nil, fmt.Errorf("%w: rounding multiple must be positive", arrow.ErrInvalid)
	}

	// Pre-calculate constants for this rounding operation.
	rs.unitNanos, rs.isSubDay = unitInNanos(rs.Unit)
	if rs.isSubDay {
		rs.roundingInterval = rs.unitNanos * rs.Multiple
		// Reject intervals that wrapped around instead of rounding to garbage.
		if rs.roundingInterval/rs.unitNanos != rs.Multiple {
			return nil, fmt.Errorf("%w: rounding multiple %d is too large for the unit",
				arrow.ErrInvalid, rs.Multiple)
		}
		// The calendar origin used by roundTimestamp is the start of the
		// local day, which is only meaningful for units finer than a day
		// (e.g. rounding to hours relative to midnight).
		rs.useCalendarOrigin = rs.CalendarBasedOrigin && rs.Unit > RoundTemporalDay
	}

	return rs, nil
}
+
+// unitInNanos returns (nanoseconds, hasFixedDuration) for a temporal unit.
+// Returns false for calendar units with variable durations (year, quarter,
month, week).
+func unitInNanos(unit RoundTemporalUnit) (int64, bool) {
+ switch unit {
+ case RoundTemporalNanosecond:
+ return 1, true
+ case RoundTemporalMicrosecond:
+ return 1000, true
+ case RoundTemporalMillisecond:
+ return 1000000, true
+ case RoundTemporalSecond:
+ return 1000000000, true
+ case RoundTemporalMinute:
+ return 60 * 1000000000, true
+ case RoundTemporalHour:
+ return 3600 * 1000000000, true
+ case RoundTemporalDay:
+ return 86400 * 1000000000, true
+ default:
+ return 0, false
+ }
+}
+
// roundTimestamp rounds ts, a count of inputUnit ticks since the Unix epoch,
// according to opts and returns the result in the same unit.
//
// tz is the timezone used for calendar-aware rounding; nil means UTC.
// Units without a fixed length (year, quarter, month, week), and day rounding
// outside UTC (local days can be 23 or 25 hours long around DST changes), are
// delegated to roundTimestampCalendar. All other units are rounded with
// integer arithmetic on opts.roundingInterval.
//
// NOTE(review): without CalendarBasedOrigin, fixed-length units are aligned to
// the UTC epoch rather than local wall-clock time, so zones whose offset is
// not a multiple of the interval get boundaries that are not local-aligned —
// confirm this matches the intended semantics.
func roundTimestamp(ts int64, inputUnit arrow.TimeUnit, tz *time.Location, opts roundTemporalState) (int64, error) {
	if tz == nil {
		tz = time.UTC
	}

	// Calendar units with variable duration require date arithmetic.
	if !opts.isSubDay {
		return roundTimestampCalendar(convertToNanos(ts, inputUnit), inputUnit, tz, opts)
	}

	// Day rounding in a non-UTC zone must follow local midnights.
	isUTC := tz == time.UTC || tz.String() == "UTC"
	if !isUTC && opts.Unit == RoundTemporalDay {
		return roundTimestampCalendar(convertToNanos(ts, inputUnit), inputUnit, tz, opts)
	}

	// Fast path: when the unit is a whole number of input ticks and no
	// calendar origin is needed, round the raw value without converting.
	if !opts.useCalendarOrigin && canRoundInInputUnit(inputUnit, opts.unitNanos) {
		interval := opts.roundingInterval / unitScaleFactor(inputUnit)
		return roundToMultipleInt64(ts, interval, opts.mode, opts.CeilIsStrictlyGreater), nil
	}

	// Slow path: work in nanoseconds.
	tsNanos := convertToNanos(ts, inputUnit)

	// tz is never nil here (defaulted above), so a calendar origin is always
	// the start of the local day containing ts.
	var origin int64
	if opts.useCalendarOrigin {
		local := time.Unix(0, tsNanos).In(tz)
		origin = time.Date(local.Year(), local.Month(), local.Day(), 0, 0, 0, 0, tz).UnixNano()
	}

	rounded := roundToMultipleInt64(tsNanos-origin, opts.roundingInterval, opts.mode, opts.CeilIsStrictlyGreater)
	return convertFromNanos(origin+rounded, inputUnit), nil
}
+
+// unitScaleFactor returns nanoseconds per unit for the given time unit
+func unitScaleFactor(unit arrow.TimeUnit) int64 {
+ switch unit {
+ case arrow.Second:
+ return 1_000_000_000
+ case arrow.Millisecond:
+ return 1_000_000
+ case arrow.Microsecond:
+ return 1_000
+ case arrow.Nanosecond:
+ return 1
+ default:
+ return 1
+ }
+}
+
+// canRoundInInputUnit checks if rounding can be done in the input unit
+// without converting to nanoseconds (true when rounding interval is evenly
divisible).
+func canRoundInInputUnit(inputUnit arrow.TimeUnit, roundingIntervalNanos
int64) bool {
+ return roundingIntervalNanos%unitScaleFactor(inputUnit) == 0
+}
+
+// convertToNanos converts a timestamp value to nanoseconds
+func convertToNanos(ts int64, unit arrow.TimeUnit) int64 {
+ return ts * unitScaleFactor(unit)
+}
+
+// convertFromNanos converts a nanosecond timestamp to the specified unit
+func convertFromNanos(nanos int64, unit arrow.TimeUnit) int64 {
+ return nanos / unitScaleFactor(unit)
+}
+
+func roundToMultipleInt64(value, multiple int64, mode RoundMode, strictCeil
bool) int64 {
+ if multiple == 0 || value%multiple == 0 {
+ if strictCeil && mode == RoundUp {
+ return value + multiple
+ }
+ return value
+ }
+
+ quotient := value / multiple
+ remainder := value % multiple
+
+ switch mode {
+ case RoundDown:
+ if remainder < 0 {
+ return (quotient - 1) * multiple
+ }
+ return quotient * multiple
+ case RoundUp:
+ if remainder > 0 || (strictCeil && remainder == 0) {
+ return (quotient + 1) * multiple
+ }
+ if remainder < 0 {
+ return quotient * multiple
+ }
+ return (quotient + 1) * multiple
+ case HalfUp, HalfDown, HalfToEven:
+ half := multiple / 2
+ absRemainder := remainder
+ if absRemainder < 0 {
+ absRemainder = -absRemainder
+ }
+
+ if absRemainder < half {
+ return quotient * multiple
+ } else if absRemainder > half {
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return (quotient - 1) * multiple
+ } else {
+ // Exactly on the halfway point
+ switch mode {
+ case HalfDown:
+ if remainder > 0 {
+ return quotient * multiple
+ }
+ return (quotient - 1) * multiple
+ case HalfUp:
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return quotient * multiple
+ case HalfToEven:
+ if quotient%2 == 0 {
+ return quotient * multiple
+ }
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return (quotient - 1) * multiple
+ }
+ }
+ }
+ return quotient * multiple
+}
+
// halfRoundPeriod returns whichever of periodStart and periodEnd is nearer to
// t. A t exactly at the midpoint (or later) rounds to periodEnd.
func halfRoundPeriod(t, periodStart, periodEnd time.Time) time.Time {
	halfSpan := periodEnd.Sub(periodStart) / 2
	if midpoint := periodStart.Add(halfSpan); !t.Before(midpoint) {
		return periodEnd
	}
	return periodStart
}
+
// roundTimestampCalendar rounds tsNanos (nanoseconds since the Unix epoch) to
// a calendar unit — year, quarter, month, week or day — evaluated as
// wall-clock time in tz, and returns the result converted to inputUnit.
//
// Period boundaries are built with time.Date in tz, so they fall on local
// midnight even across DST transitions. Periods of Multiple units are counted
// from year 0 for years, quarters and months, from the week containing
// 1970-01-01 for weeks, and from 1970-01-01 for days; counts before the origin
// are floored, so earlier dates group the same way as later ones. All
// half-rounding modes send a value at the exact midpoint to the period end.
func roundTimestampCalendar(tsNanos int64, inputUnit arrow.TimeUnit, tz *time.Location, opts roundTemporalState) (int64, error) {
	t := time.Unix(0, tsNanos).In(tz)
	multiple := int(opts.Multiple)

	// floorDiv divides toward negative infinity; b is always positive here.
	floorDiv := func(a, b int) int {
		q := a / b
		if a%b < 0 {
			q--
		}
		return q
	}
	// snap rounds a period count down to a multiple of Multiple.
	snap := func(n int) int { return floorDiv(n, multiple) * multiple }
	// civilDay is the number of calendar days from 1970-01-01 to x's local
	// date. Computed on UTC dates so DST cannot skew the count.
	civilDay := func(x time.Time) int {
		return int(time.Date(x.Year(), x.Month(), x.Day(), 0, 0, 0, 0, time.UTC).Unix() / 86400)
	}
	// dayStart is local midnight of the date `day` days after 1970-01-01;
	// time.Date normalizes the out-of-range day number.
	dayStart := func(day int) time.Time {
		return time.Date(1970, time.January, 1+day, 0, 0, 0, 0, tz)
	}
	// monthStart is local midnight on the first day of month index idx,
	// where idx = year*12 + (month-1).
	monthStart := func(idx int) time.Time {
		return time.Date(0, time.Month(idx+1), 1, 0, 0, 0, 0, tz)
	}

	var periodStart, periodEnd time.Time
	switch opts.Unit {
	case RoundTemporalYear:
		year := snap(t.Year())
		periodStart = time.Date(year, time.January, 1, 0, 0, 0, 0, tz)
		periodEnd = time.Date(year+multiple, time.January, 1, 0, 0, 0, 0, tz)

	case RoundTemporalQuarter, RoundTemporalMonth:
		// Quarters are Jan-Mar, Apr-Jun, Jul-Sep and Oct-Dec.
		monthsPerUnit := 1
		if opts.Unit == RoundTemporalQuarter {
			monthsPerUnit = 3
		}
		monthIdx := t.Year()*12 + int(t.Month()) - 1
		startIdx := snap(floorDiv(monthIdx, monthsPerUnit)) * monthsPerUnit
		periodStart = monthStart(startIdx)
		periodEnd = monthStart(startIdx + multiple*monthsPerUnit)

	case RoundTemporalWeek:
		// 1970-01-01 was a Thursday, so the week containing it began on
		// Sunday 1969-12-28 (day -4) or Monday 1969-12-29 (day -3).
		epochWeekStart := -4
		if opts.WeekStartsMonday {
			epochWeekStart = -3
		}
		week := snap(floorDiv(civilDay(t)-epochWeekStart, 7))
		startDay := epochWeekStart + week*7
		periodStart = dayStart(startDay)
		periodEnd = dayStart(startDay + 7*multiple)

	case RoundTemporalDay:
		day := snap(civilDay(t))
		periodStart = dayStart(day)
		periodEnd = dayStart(day + multiple)

	default:
		return 0, fmt.Errorf("%w: unsupported calendar unit", arrow.ErrNotImplemented)
	}

	var rounded time.Time
	switch opts.mode {
	case RoundDown:
		rounded = periodStart
	case RoundUp:
		if opts.CeilIsStrictlyGreater || !t.Equal(periodStart) {
			rounded = periodEnd
		} else {
			rounded = periodStart
		}
	default:
		rounded = halfRoundPeriod(t, periodStart, periodEnd)
	}

	// Convert back to the input unit.
	return convertFromNanos(rounded.UnixNano(), inputUnit), nil
}
+
+// Kernel execution functions for temporal rounding
+func FloorTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = RoundDown
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func CeilTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = RoundUp
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func RoundTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = HalfUp
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
// roundTemporalExec rounds every non-null value of the single array argument
// in batch according to state (whose mode the calling kernel has set) and
// writes the results to out.
//
// Supported inputs:
//   - date32: days are treated as midnight-UTC timestamps in seconds; the
//     rounded result is floored back to whole days, so a result that is not
//     day-aligned maps to the day containing it.
//   - date64: treated as a UTC timestamp in milliseconds.
//   - time32/time64: rounded as an offset from midnight and wrapped into
//     [0, 24h), so rounding past midnight comes back to 00:00.
//   - timestamp: rounded in the type's timezone (UTC when none is set).
//
// Types that do not report a time unit yield an arrow.ErrNotImplemented error.
func roundTemporalExec(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult, state roundTemporalState) error {
	input := &batch.Values[0].Array

	// wrapToDay folds a time-of-day value into [0, dayLen).
	wrapToDay := func(v, dayLen int64) int64 {
		v %= dayLen
		if v < 0 {
			v += dayLen
		}
		return v
	}

	switch input.Type.ID() {
	case arrow.DATE32:
		const secondsPerDay = 86400
		fn := func(_ *exec.KernelCtx, days int32, e *error) int32 {
			result, err := roundTimestamp(int64(days)*secondsPerDay, arrow.Second, nil, state)
			if err != nil {
				*e = err
				return 0
			}
			// Floor instead of truncating toward zero so pre-1970 results
			// land on the day that contains them.
			outDays := result / secondsPerDay
			if result%secondsPerDay < 0 {
				outDays--
			}
			return int32(outDays)
		}
		return ScalarUnaryNotNull(fn)(ctx, batch, out)

	case arrow.DATE64:
		fn := func(_ *exec.KernelCtx, ms int64, e *error) int64 {
			result, err := roundTimestamp(ms, arrow.Millisecond, nil, state)
			if err != nil {
				*e = err
				return 0
			}
			return result
		}
		return ScalarUnaryNotNull(fn)(ctx, batch, out)

	case arrow.TIME32:
		unit := input.Type.(*arrow.Time32Type).Unit
		dayLen := int64(86400) // 24 hours in seconds
		if unit != arrow.Second {
			dayLen = 86400000 // 24 hours in milliseconds
		}
		fn := func(_ *exec.KernelCtx, v int32, e *error) int32 {
			result, err := roundTimestamp(int64(v), unit, nil, state)
			if err != nil {
				*e = err
				return 0
			}
			return int32(wrapToDay(result, dayLen))
		}
		return ScalarUnaryNotNull(fn)(ctx, batch, out)

	case arrow.TIME64:
		unit := input.Type.(*arrow.Time64Type).Unit
		dayLen := int64(86400000000) // 24 hours in microseconds
		if unit != arrow.Microsecond {
			dayLen = 86400000000000 // 24 hours in nanoseconds
		}
		fn := func(_ *exec.KernelCtx, v int64, e *error) int64 {
			result, err := roundTimestamp(v, unit, nil, state)
			if err != nil {
				*e = err
				return 0
			}
			return wrapToDay(result, dayLen)
		}
		return ScalarUnaryNotNull(fn)(ctx, batch, out)
	}

	// Timestamp (and any other unit-bearing) types.
	inputType, ok := input.Type.(arrow.TemporalWithUnit)
	if !ok {
		return fmt.Errorf("%w: temporal rounding of type %s", arrow.ErrNotImplemented, input.Type)
	}

	// Resolve the timestamp's timezone name via time.LoadLocation.
	var tz *time.Location
	if tsType, isTS := input.Type.(*arrow.TimestampType); isTS && tsType.TimeZone != "" {
		loc, err := time.LoadLocation(tsType.TimeZone)
		if err != nil {
			return fmt.Errorf("%w: invalid timezone %q: %v", arrow.ErrInvalid, tsType.TimeZone, err)
		}
		tz = loc
	}

	unit := inputType.TimeUnit()
	switch unit {
	case arrow.Second, arrow.Millisecond, arrow.Microsecond, arrow.Nanosecond:
	default:
		return fmt.Errorf("%w: unsupported time unit", arrow.ErrNotImplemented)
	}

	fn := func(_ *exec.KernelCtx, ts int64, e *error) int64 {
		result, err := roundTimestamp(ts, unit, tz, state)
		if err != nil {
			*e = err
			return 0
		}
		return result
	}
	return ScalarUnaryNotNull(fn)(ctx, batch, out)
}
+
+type timestampUnitMatcher struct {
+ unit arrow.TimeUnit
+}
+
+func (m *timestampUnitMatcher) Matches(typ arrow.DataType) bool {
+ if ts, ok := typ.(*arrow.TimestampType); ok {
+ return ts.Unit == m.unit
+ }
+ return false
+}
+
+func (m *timestampUnitMatcher) String() string {
+ return "timestamp(unit=" + m.unit.String() + ")"
+}
+
+func (m *timestampUnitMatcher) Equals(other exec.TypeMatcher) bool {
+ if o, ok := other.(*timestampUnitMatcher); ok {
+ return m.unit == o.unit
+ }
+ return false
+}
+
+type dateTypeMatcher struct {
+ dateTypeID arrow.Type
+}
+
+func (m *dateTypeMatcher) Matches(typ arrow.DataType) bool {
+ return typ.ID() == m.dateTypeID
+}
+
+func (m *dateTypeMatcher) String() string {
+ if m.dateTypeID == arrow.DATE32 {
+ return "date32"
+ }
+ return "date64"
+}
+
+func (m *dateTypeMatcher) Equals(other exec.TypeMatcher) bool {
+ if o, ok := other.(*dateTypeMatcher); ok {
+ return m.dateTypeID == o.dateTypeID
+ }
+ return false
+}
+
+type timeTypeMatcher struct {
+ timeTypeID arrow.Type
+ unit arrow.TimeUnit
+}
+
+func (m *timeTypeMatcher) Matches(typ arrow.DataType) bool {
+ if typ.ID() != m.timeTypeID {
+ return false
+ }
+ switch t := typ.(type) {
+ case *arrow.Time32Type:
+ return t.Unit == m.unit
+ case *arrow.Time64Type:
+ return t.Unit == m.unit
+ }
+ return false
+}
+
+func (m *timeTypeMatcher) String() string {
+ if m.timeTypeID == arrow.TIME32 {
+ return fmt.Sprintf("time32[%s]", m.unit)
+ }
+ return fmt.Sprintf("time64[%s]", m.unit)
+}
+
+func (m *timeTypeMatcher) Equals(other exec.TypeMatcher) bool {
+ if o, ok := other.(*timeTypeMatcher); ok {
+ return m.timeTypeID == o.timeTypeID && m.unit == o.unit
+ }
+ return false
+}
Review Comment:
This already exists as `Time32TypeUnit(...)`, `Time64TypeUnit(...)`, etc.
in `exec/kernel.go`, which can be used here instead.
##########
arrow/compute/internal/kernels/rounding.go:
##########
@@ -807,3 +808,681 @@ func FixedRoundDecimalExec[T decimal128.Num |
decimal256.Num](mode RoundMode) ex
}
panic("should never get here")
}
+
// RoundTemporalUnit selects the calendar or clock unit that temporal rounding
// snaps values to. The constants are declared from the coarsest unit (year)
// to the finest (nanosecond), and their numeric order is significant: code
// compares units against RoundTemporalDay.
type RoundTemporalUnit int8

const (
	RoundTemporalYear        RoundTemporalUnit = iota // calendar year
	RoundTemporalQuarter                              // three calendar months
	RoundTemporalMonth                                // calendar month
	RoundTemporalWeek                                 // seven days
	RoundTemporalDay                                  // one day
	RoundTemporalHour                                 // one hour
	RoundTemporalMinute                               // one minute
	RoundTemporalSecond                               // one second
	RoundTemporalMillisecond                          // one millisecond
	RoundTemporalMicrosecond                          // one microsecond
	RoundTemporalNanosecond                           // one nanosecond
)
+
+// RoundTemporalOptions provides configuration for temporal rounding operations
+type RoundTemporalOptions struct {
+ // Multiple is the number of units to round to. Must be positive.
+ Multiple int64
+ // Unit is the rounding unit (day, hour, etc.)
+ Unit RoundTemporalUnit
+ // WeekStartsMonday determines the start of the week for week-based
rounding
+ WeekStartsMonday bool
+ // CeilIsStrictlyGreater: if true, ceil returns a value strictly
greater than input
+ CeilIsStrictlyGreater bool
+ // CalendarBasedOrigin: if true, use calendar units as origin (e.g.,
start of day for hours)
+ CalendarBasedOrigin bool
+}
+
+func (RoundTemporalOptions) TypeName() string { return "RoundTemporalOptions" }
+
+type roundTemporalState struct {
+ RoundTemporalOptions
+ mode RoundMode
+
+ // Pre-calculated values to avoid repeated computation
+ unitNanos int64 // Duration of the unit in nanoseconds
+ roundingInterval int64 // unitNanos * Multiple
+ isSubDay bool // true if this is a sub-day unit (can use fast
path)
+ useCalendarOrigin bool // true if using calendar-based origin
+}
+
// InitRoundTemporalState validates the RoundTemporalOptions handed to a
// temporal rounding kernel and precomputes the values cached in
// roundTemporalState.
//
// The options may be passed by value or by (non-nil) pointer. An error
// wrapping arrow.ErrInvalid is returned when the options have the wrong type,
// when Multiple is not positive, or when Multiple units do not fit in an int64
// count of nanoseconds.
func InitRoundTemporalState(_ *exec.KernelCtx, args exec.KernelInitArgs) (exec.KernelState, error) {
	var rs roundTemporalState

	switch opts := args.Options.(type) {
	case *RoundTemporalOptions:
		if opts == nil {
			return nil, fmt.Errorf("%w: attempted to initialize kernel state from invalid function options",
				arrow.ErrInvalid)
		}
		rs.RoundTemporalOptions = *opts
	case RoundTemporalOptions:
		rs.RoundTemporalOptions = opts
	default:
		return nil, fmt.Errorf("%w: attempted to initialize kernel state from invalid function options",
			arrow.ErrInvalid)
	}

	if rs.Multiple <= 0 {
		return nil, fmt.Errorf("%w: rounding multiple must be positive", arrow.ErrInvalid)
	}

	// Pre-calculate constants for this rounding operation.
	rs.unitNanos, rs.isSubDay = unitInNanos(rs.Unit)
	if rs.isSubDay {
		rs.roundingInterval = rs.unitNanos * rs.Multiple
		// Reject intervals that wrapped around instead of rounding to garbage.
		if rs.roundingInterval/rs.unitNanos != rs.Multiple {
			return nil, fmt.Errorf("%w: rounding multiple %d is too large for the unit",
				arrow.ErrInvalid, rs.Multiple)
		}
		// The calendar origin used by roundTimestamp is the start of the
		// local day, which is only meaningful for units finer than a day
		// (e.g. rounding to hours relative to midnight).
		rs.useCalendarOrigin = rs.CalendarBasedOrigin && rs.Unit > RoundTemporalDay
	}

	return rs, nil
}
+
+// unitInNanos returns (nanoseconds, hasFixedDuration) for a temporal unit.
+// Returns false for calendar units with variable durations (year, quarter,
month, week).
+func unitInNanos(unit RoundTemporalUnit) (int64, bool) {
+ switch unit {
+ case RoundTemporalNanosecond:
+ return 1, true
+ case RoundTemporalMicrosecond:
+ return 1000, true
+ case RoundTemporalMillisecond:
+ return 1000000, true
+ case RoundTemporalSecond:
+ return 1000000000, true
+ case RoundTemporalMinute:
+ return 60 * 1000000000, true
+ case RoundTemporalHour:
+ return 3600 * 1000000000, true
+ case RoundTemporalDay:
+ return 86400 * 1000000000, true
+ default:
+ return 0, false
+ }
+}
+
// roundTimestamp rounds ts, a count of inputUnit ticks since the Unix epoch,
// according to opts and returns the result in the same unit.
//
// tz is the timezone used for calendar-aware rounding; nil means UTC.
// Units without a fixed length (year, quarter, month, week), and day rounding
// outside UTC (local days can be 23 or 25 hours long around DST changes), are
// delegated to roundTimestampCalendar. All other units are rounded with
// integer arithmetic on opts.roundingInterval.
//
// NOTE(review): without CalendarBasedOrigin, fixed-length units are aligned to
// the UTC epoch rather than local wall-clock time, so zones whose offset is
// not a multiple of the interval get boundaries that are not local-aligned —
// confirm this matches the intended semantics.
func roundTimestamp(ts int64, inputUnit arrow.TimeUnit, tz *time.Location, opts roundTemporalState) (int64, error) {
	if tz == nil {
		tz = time.UTC
	}

	// Calendar units with variable duration require date arithmetic.
	if !opts.isSubDay {
		return roundTimestampCalendar(convertToNanos(ts, inputUnit), inputUnit, tz, opts)
	}

	// Day rounding in a non-UTC zone must follow local midnights.
	isUTC := tz == time.UTC || tz.String() == "UTC"
	if !isUTC && opts.Unit == RoundTemporalDay {
		return roundTimestampCalendar(convertToNanos(ts, inputUnit), inputUnit, tz, opts)
	}

	// Fast path: when the unit is a whole number of input ticks and no
	// calendar origin is needed, round the raw value without converting.
	if !opts.useCalendarOrigin && canRoundInInputUnit(inputUnit, opts.unitNanos) {
		interval := opts.roundingInterval / unitScaleFactor(inputUnit)
		return roundToMultipleInt64(ts, interval, opts.mode, opts.CeilIsStrictlyGreater), nil
	}

	// Slow path: work in nanoseconds.
	tsNanos := convertToNanos(ts, inputUnit)

	// tz is never nil here (defaulted above), so a calendar origin is always
	// the start of the local day containing ts.
	var origin int64
	if opts.useCalendarOrigin {
		local := time.Unix(0, tsNanos).In(tz)
		origin = time.Date(local.Year(), local.Month(), local.Day(), 0, 0, 0, 0, tz).UnixNano()
	}

	rounded := roundToMultipleInt64(tsNanos-origin, opts.roundingInterval, opts.mode, opts.CeilIsStrictlyGreater)
	return convertFromNanos(origin+rounded, inputUnit), nil
}
+
+// unitScaleFactor returns nanoseconds per unit for the given time unit
+func unitScaleFactor(unit arrow.TimeUnit) int64 {
+ switch unit {
+ case arrow.Second:
+ return 1_000_000_000
+ case arrow.Millisecond:
+ return 1_000_000
+ case arrow.Microsecond:
+ return 1_000
+ case arrow.Nanosecond:
+ return 1
+ default:
+ return 1
+ }
+}
+
+// canRoundInInputUnit checks if rounding can be done in the input unit
+// without converting to nanoseconds (true when rounding interval is evenly
divisible).
+func canRoundInInputUnit(inputUnit arrow.TimeUnit, roundingIntervalNanos
int64) bool {
+ return roundingIntervalNanos%unitScaleFactor(inputUnit) == 0
+}
+
+// convertToNanos converts a timestamp value to nanoseconds
+func convertToNanos(ts int64, unit arrow.TimeUnit) int64 {
+ return ts * unitScaleFactor(unit)
Review Comment:
```suggestion
return ts * int64(unit.Multiplier())
```
##########
arrow/compute/internal/kernels/rounding.go:
##########
@@ -807,3 +808,681 @@ func FixedRoundDecimalExec[T decimal128.Num |
decimal256.Num](mode RoundMode) ex
}
panic("should never get here")
}
+
// RoundTemporalUnit selects the unit that temporal floor/ceil/round
// operations round to.
//
// The values are ordered from the coarsest unit (year) to the finest
// (nanosecond). Code compares units by this order (for example against
// RoundTemporalDay), so any new unit must preserve it.
type RoundTemporalUnit int8

// Supported temporal rounding units, coarsest first.
const (
	RoundTemporalYear        RoundTemporalUnit = 0
	RoundTemporalQuarter     RoundTemporalUnit = 1
	RoundTemporalMonth       RoundTemporalUnit = 2
	RoundTemporalWeek        RoundTemporalUnit = 3
	RoundTemporalDay         RoundTemporalUnit = 4
	RoundTemporalHour        RoundTemporalUnit = 5
	RoundTemporalMinute      RoundTemporalUnit = 6
	RoundTemporalSecond      RoundTemporalUnit = 7
	RoundTemporalMillisecond RoundTemporalUnit = 8
	RoundTemporalMicrosecond RoundTemporalUnit = 9
	RoundTemporalNanosecond  RoundTemporalUnit = 10
)
+
+// RoundTemporalOptions provides configuration for temporal rounding operations
+type RoundTemporalOptions struct {
+ // Multiple is the number of units to round to. Must be positive.
+ Multiple int64
+ // Unit is the rounding unit (day, hour, etc.)
+ Unit RoundTemporalUnit
+ // WeekStartsMonday determines the start of the week for week-based
rounding
+ WeekStartsMonday bool
+ // CeilIsStrictlyGreater: if true, ceil returns a value strictly
greater than input
+ CeilIsStrictlyGreater bool
+ // CalendarBasedOrigin: if true, use calendar units as origin (e.g.,
start of day for hours)
+ CalendarBasedOrigin bool
+}
+
+func (RoundTemporalOptions) TypeName() string { return "RoundTemporalOptions" }
+
+type roundTemporalState struct {
+ RoundTemporalOptions
+ mode RoundMode
+
+ // Pre-calculated values to avoid repeated computation
+ unitNanos int64 // Duration of the unit in nanoseconds
+ roundingInterval int64 // unitNanos * Multiple
+ isSubDay bool // true if this is a sub-day unit (can use fast
path)
+ useCalendarOrigin bool // true if using calendar-based origin
+}
+
+func InitRoundTemporalState(_ *exec.KernelCtx, args exec.KernelInitArgs)
(exec.KernelState, error) {
+ var rs roundTemporalState
+
+ opts, ok := args.Options.(*RoundTemporalOptions)
+ if ok {
+ rs.RoundTemporalOptions = *opts
+ } else {
+ if rs.RoundTemporalOptions, ok =
args.Options.(RoundTemporalOptions); !ok {
+ return nil, fmt.Errorf("%w: attempted to initialize
kernel state from invalid function options",
+ arrow.ErrInvalid)
+ }
+ }
+
+ if rs.Multiple <= 0 {
+ return nil, fmt.Errorf("%w: rounding multiple must be
positive", arrow.ErrInvalid)
+ }
+
+ // Pre-calculate constants for this rounding operation
+ rs.unitNanos, rs.isSubDay = unitInNanos(rs.Unit)
+ if rs.isSubDay {
+ rs.roundingInterval = rs.unitNanos * rs.Multiple
+ rs.useCalendarOrigin = rs.CalendarBasedOrigin && rs.Unit <=
RoundTemporalDay
+ }
+
+ return rs, nil
+}
+
+// unitInNanos returns (nanoseconds, hasFixedDuration) for a temporal unit.
+// Returns false for calendar units with variable durations (year, quarter,
month, week).
+func unitInNanos(unit RoundTemporalUnit) (int64, bool) {
+ switch unit {
+ case RoundTemporalNanosecond:
+ return 1, true
+ case RoundTemporalMicrosecond:
+ return 1000, true
+ case RoundTemporalMillisecond:
+ return 1000000, true
+ case RoundTemporalSecond:
+ return 1000000000, true
+ case RoundTemporalMinute:
+ return 60 * 1000000000, true
+ case RoundTemporalHour:
+ return 3600 * 1000000000, true
+ case RoundTemporalDay:
+ return 86400 * 1000000000, true
+ default:
+ return 0, false
+ }
+}
+
+// roundTimestamp rounds a timestamp value according to the specified options.
+// tz specifies the timezone for calendar-aware rounding (nil defaults to UTC).
+func roundTimestamp(ts int64, inputUnit arrow.TimeUnit, tz *time.Location,
opts roundTemporalState) (int64, error) {
+ if tz == nil {
+ tz = time.UTC
+ }
+
+ // Calendar units with variable duration (year, quarter, month, week)
require date arithmetic
+ if !opts.isSubDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Day rounding with timezone requires calendar arithmetic (days vary:
23/24/25 hours due to DST)
+ isUTC := tz == time.UTC || tz.String() == "UTC"
+ if !isUTC && opts.Unit == RoundTemporalDay {
+ tsNanos := convertToNanos(ts, inputUnit)
+ return roundTimestampCalendar(tsNanos, inputUnit, tz, opts)
+ }
+
+ // Sub-day units (hour, minute, second, etc.) use fixed-duration
arithmetic
+ // Fast path: round directly in input unit if possible (no origin,
compatible units)
+ if canRoundInInputUnit(inputUnit, opts.unitNanos) &&
!opts.useCalendarOrigin {
+ intervalInInputUnit := opts.roundingInterval /
unitScaleFactor(inputUnit)
+ rounded := roundToMultipleInt64(ts, intervalInInputUnit,
opts.mode, opts.CeilIsStrictlyGreater)
+ return rounded, nil
+ }
+
+ // Slow path: convert to nanoseconds for calendar origin or
incompatible units
+ tsNanos := convertToNanos(ts, inputUnit)
+
+ var origin int64 = 0
+ if opts.useCalendarOrigin {
+ // Calendar origin: round relative to start of day
(timezone-aware if tz != nil)
+ if tz != nil {
+ t := time.Unix(0, tsNanos).In(tz)
+ startOfDay := time.Date(t.Year(), t.Month(), t.Day(),
0, 0, 0, 0, tz)
+ origin = startOfDay.UnixNano()
+ } else {
+ origin = tsNanos
+ }
+ }
+
+ adjusted := tsNanos - origin
+ rounded := roundToMultipleInt64(adjusted, opts.roundingInterval,
opts.mode, opts.CeilIsStrictlyGreater)
+ result := origin + rounded
+
+ return convertFromNanos(result, inputUnit), nil
+}
+
+// unitScaleFactor returns nanoseconds per unit for the given time unit
+func unitScaleFactor(unit arrow.TimeUnit) int64 {
+ switch unit {
+ case arrow.Second:
+ return 1_000_000_000
+ case arrow.Millisecond:
+ return 1_000_000
+ case arrow.Microsecond:
+ return 1_000
+ case arrow.Nanosecond:
+ return 1
+ default:
+ return 1
+ }
+}
+
+// canRoundInInputUnit checks if rounding can be done in the input unit
+// without converting to nanoseconds (true when rounding interval is evenly
divisible).
+func canRoundInInputUnit(inputUnit arrow.TimeUnit, roundingIntervalNanos
int64) bool {
+ return roundingIntervalNanos%unitScaleFactor(inputUnit) == 0
+}
+
+// convertToNanos converts a timestamp value to nanoseconds
+func convertToNanos(ts int64, unit arrow.TimeUnit) int64 {
+ return ts * unitScaleFactor(unit)
+}
+
+// convertFromNanos converts a nanosecond timestamp to the specified unit
+func convertFromNanos(nanos int64, unit arrow.TimeUnit) int64 {
+ return nanos / unitScaleFactor(unit)
+}
+
+func roundToMultipleInt64(value, multiple int64, mode RoundMode, strictCeil
bool) int64 {
+ if multiple == 0 || value%multiple == 0 {
+ if strictCeil && mode == RoundUp {
+ return value + multiple
+ }
+ return value
+ }
+
+ quotient := value / multiple
+ remainder := value % multiple
+
+ switch mode {
+ case RoundDown:
+ if remainder < 0 {
+ return (quotient - 1) * multiple
+ }
+ return quotient * multiple
+ case RoundUp:
+ if remainder > 0 || (strictCeil && remainder == 0) {
+ return (quotient + 1) * multiple
+ }
+ if remainder < 0 {
+ return quotient * multiple
+ }
+ return (quotient + 1) * multiple
+ case HalfUp, HalfDown, HalfToEven:
+ half := multiple / 2
+ absRemainder := remainder
+ if absRemainder < 0 {
+ absRemainder = -absRemainder
+ }
+
+ if absRemainder < half {
+ return quotient * multiple
+ } else if absRemainder > half {
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return (quotient - 1) * multiple
+ } else {
+ // Exactly on the halfway point
+ switch mode {
+ case HalfDown:
+ if remainder > 0 {
+ return quotient * multiple
+ }
+ return (quotient - 1) * multiple
+ case HalfUp:
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return quotient * multiple
+ case HalfToEven:
+ if quotient%2 == 0 {
+ return quotient * multiple
+ }
+ if remainder > 0 {
+ return (quotient + 1) * multiple
+ }
+ return (quotient - 1) * multiple
+ }
+ }
+ }
+ return quotient * multiple
+}
+
// halfRoundPeriod rounds t to whichever of periodStart and periodEnd is
// nearer, measured in elapsed time. A t exactly at the midpoint rounds to
// periodEnd.
func halfRoundPeriod(t, periodStart, periodEnd time.Time) time.Time {
	halfway := periodStart.Add(periodEnd.Sub(periodStart) / 2)
	if !t.Before(halfway) {
		return periodEnd
	}
	return periodStart
}
+
+// roundTimestampCalendar handles calendar-based rounding (year, quarter,
month, week, day).
+// Requires date arithmetic for variable-length periods and timezone-aware
boundaries.
+func roundTimestampCalendar(tsNanos int64, inputUnit arrow.TimeUnit, tz
*time.Location, opts roundTemporalState) (int64, error) {
+ // Convert to time.Time for calendar operations in the specified
timezone
+ secs := tsNanos / 1000000000
+ nanos := tsNanos % 1000000000
+ t := time.Unix(secs, nanos).In(tz)
+
+ var rounded time.Time
+
+ switch opts.Unit {
+ case RoundTemporalYear:
+ year := t.Year()
+ roundedYear := (year / int(opts.Multiple)) * int(opts.Multiple)
+ switch opts.mode {
+ case RoundDown:
+ rounded = time.Date(roundedYear, 1, 1, 0, 0, 0, 0, tz)
+ case RoundUp:
+ periodStart := time.Date(roundedYear, 1, 1, 0, 0, 0, 0,
tz)
+ if opts.CeilIsStrictlyGreater || !t.Equal(periodStart) {
+ roundedYear += int(opts.Multiple)
+ rounded = time.Date(roundedYear, 1, 1, 0, 0, 0,
0, tz)
+ } else {
+ rounded = periodStart
+ }
+ default:
+ yearStart := time.Date(roundedYear, 1, 1, 0, 0, 0, 0,
tz)
+ nextYear := roundedYear + int(opts.Multiple)
+ yearEnd := time.Date(nextYear, 1, 1, 0, 0, 0, 0, tz)
+ rounded = halfRoundPeriod(t, yearStart, yearEnd)
+ }
+
+ case RoundTemporalQuarter:
+ // Q1=Jan-Mar, Q2=Apr-Jun, Q3=Jul-Sep, Q4=Oct-Dec
+ month := int(t.Month())
+ year := t.Year()
+ totalQuarters := year*4 + (month-1)/3
+ roundedQuarters := (totalQuarters / int(opts.Multiple)) *
int(opts.Multiple)
+ roundedYear := roundedQuarters / 4
+ roundedQuarter := roundedQuarters % 4
+ roundedMonth := roundedQuarter*3 + 1 // First month of the
quarter
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ case RoundUp:
+ periodStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ if opts.CeilIsStrictlyGreater || !t.Equal(periodStart) {
+ roundedQuarters += int(opts.Multiple)
+ roundedYear = roundedQuarters / 4
+ roundedQuarter = roundedQuarters % 4
+ roundedMonth = roundedQuarter*3 + 1
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ } else {
+ rounded = periodStart
+ }
+ default:
+ quarterStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ nextQuarterNum := roundedQuarters + int(opts.Multiple)
+ nextYear := nextQuarterNum / 4
+ nextQuarter := nextQuarterNum % 4
+ nextMonth := nextQuarter*3 + 1
+ quarterEnd := time.Date(nextYear,
time.Month(nextMonth), 1, 0, 0, 0, 0, tz)
+ rounded = halfRoundPeriod(t, quarterStart, quarterEnd)
+ }
+
+ case RoundTemporalMonth:
+ month := int(t.Month())
+ year := t.Year()
+ totalMonths := year*12 + month - 1
+ roundedMonths := (totalMonths / int(opts.Multiple)) *
int(opts.Multiple)
+ roundedYear := roundedMonths / 12
+ roundedMonth := (roundedMonths % 12) + 1
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ case RoundUp:
+ periodStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ if opts.CeilIsStrictlyGreater || !t.Equal(periodStart) {
+ roundedMonths += int(opts.Multiple)
+ roundedYear = roundedMonths / 12
+ roundedMonth = (roundedMonths % 12) + 1
+ rounded = time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ } else {
+ rounded = periodStart
+ }
+ default:
+ monthStart := time.Date(roundedYear,
time.Month(roundedMonth), 1, 0, 0, 0, 0, tz)
+ nextMonthNum := roundedMonths + int(opts.Multiple)
+ nextYear := nextMonthNum / 12
+ nextMonth := (nextMonthNum % 12) + 1
+ monthEnd := time.Date(nextYear, time.Month(nextMonth),
1, 0, 0, 0, 0, tz)
+ rounded = halfRoundPeriod(t, monthStart, monthEnd)
+ }
+
+ case RoundTemporalWeek:
+ weekday := int(t.Weekday())
+ if opts.WeekStartsMonday {
+ weekday = (weekday + 6) % 7
+ }
+ startOfWeek := t.AddDate(0, 0, -weekday)
+ startOfWeek = time.Date(startOfWeek.Year(),
startOfWeek.Month(), startOfWeek.Day(), 0, 0, 0, 0, tz)
+
+ // Calculate N-week periods from epoch for Multiple > 1
+ epochInTz := time.Unix(0, 0).In(tz)
+ epochWeekday := int(epochInTz.Weekday())
+ if opts.WeekStartsMonday {
+ epochWeekday = (epochWeekday + 6) % 7
+ }
+ epochWeekStart := epochInTz.AddDate(0, 0, -epochWeekday)
+ epochWeekStart = time.Date(epochWeekStart.Year(),
epochWeekStart.Month(), epochWeekStart.Day(), 0, 0, 0, 0, tz)
+
+ daysSinceEpochWeek :=
int(startOfWeek.Sub(epochWeekStart).Hours() / 24)
+ weeksSinceEpoch := daysSinceEpochWeek / 7
+ roundedWeeks := (weeksSinceEpoch / int(opts.Multiple)) *
int(opts.Multiple)
+ roundedWeekStart := epochWeekStart.AddDate(0, 0, roundedWeeks*7)
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = roundedWeekStart
+ case RoundUp:
+ if opts.CeilIsStrictlyGreater ||
!t.Equal(roundedWeekStart) {
+ rounded = roundedWeekStart.AddDate(0, 0,
7*int(opts.Multiple))
+ } else {
+ rounded = roundedWeekStart
+ }
+ default:
+ weekEnd := roundedWeekStart.AddDate(0, 0,
7*int(opts.Multiple))
+ rounded = halfRoundPeriod(t, roundedWeekStart, weekEnd)
+ }
+
+ case RoundTemporalDay:
+ startOfDay := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0,
0, tz)
+
+ switch opts.mode {
+ case RoundDown:
+ rounded = startOfDay
+ case RoundUp:
+ if opts.CeilIsStrictlyGreater || !t.Equal(startOfDay) {
+ rounded = startOfDay.AddDate(0, 0, 1)
+ } else {
+ rounded = startOfDay
+ }
+ default:
+ nextDay := startOfDay.AddDate(0, 0, 1)
+ rounded = halfRoundPeriod(t, startOfDay, nextDay)
+ }
+
+ default:
+ return 0, fmt.Errorf("%w: unsupported calendar unit",
arrow.ErrNotImplemented)
+ }
+
+ // Convert back to the input unit
+ roundedNanos := rounded.UnixNano()
+ return convertFromNanos(roundedNanos, inputUnit), nil
+}
+
+// Kernel execution functions for temporal rounding
+func FloorTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = RoundDown
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func CeilTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = RoundUp
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func RoundTemporalKernel(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ state := ctx.State.(roundTemporalState)
+ state.mode = HalfUp
+ return roundTemporalExec(ctx, batch, out, state)
+}
+
+func roundTemporalExec(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult, state roundTemporalState) error {
+ input := &batch.Values[0].Array
+
+ // Handle date types by converting to timestamp equivalents
+ switch input.Type.ID() {
+ case arrow.DATE32:
+ // Date32 stores days since epoch as int32, treat as
timestamp[s] at midnight
+ fn := func(_ *exec.KernelCtx, days int32, e *error) int32 {
+ // Convert days to seconds (timestamp at midnight UTC)
+ tsSeconds := int64(days) * 86400
+ result, err := roundTimestamp(tsSeconds, arrow.Second,
nil, state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ // Convert back to days
+ return int32(result / 86400)
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+
+ case arrow.DATE64:
+ // Date64 stores milliseconds since epoch, treat as
timestamp[ms]
+ fn := func(_ *exec.KernelCtx, ms int64, e *error) int64 {
+ result, err := roundTimestamp(ms, arrow.Millisecond,
nil, state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ return result
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+
+ case arrow.TIME32:
+ // Time32 stores time-of-day in seconds or milliseconds
+ // Rounding wraps at day boundaries (modulo 24 hours)
+ timeType := input.Type.(*arrow.Time32Type)
+ fn := func(_ *exec.KernelCtx, time int32, e *error) int32 {
+ // Convert to int64 for rounding
+ result, err := roundTimestamp(int64(time),
timeType.Unit, nil, state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ // Wrap at day boundary
+ var dayInUnit int64
+ if timeType.Unit == arrow.Second {
+ dayInUnit = 86400 // 24 hours in seconds
+ } else {
+ dayInUnit = 86400000 // 24 hours in milliseconds
+ }
+ wrapped := result % dayInUnit
+ if wrapped < 0 {
+ wrapped += dayInUnit
+ }
+ return int32(wrapped)
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+
+ case arrow.TIME64:
+ // Time64 stores time-of-day in microseconds or nanoseconds
+ // Rounding wraps at day boundaries (modulo 24 hours)
+ timeType := input.Type.(*arrow.Time64Type)
+ fn := func(_ *exec.KernelCtx, time int64, e *error) int64 {
+ result, err := roundTimestamp(time, timeType.Unit, nil,
state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ // Wrap at day boundary
+ var dayInUnit int64
+ if timeType.Unit == arrow.Microsecond {
+ dayInUnit = 86400000000 // 24 hours in
microseconds
+ } else {
+ dayInUnit = 86400000000000 // 24 hours in
nanoseconds
+ }
+ wrapped := result % dayInUnit
+ if wrapped < 0 {
+ wrapped += dayInUnit
+ }
+ return wrapped
+ }
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+ }
+
+ // Handle timestamp types
+ inputType := input.Type.(arrow.TemporalWithUnit)
+
+ // Extract timezone if present (for timestamp types)
+ var tz *time.Location
+ if tsType, ok := input.Type.(*arrow.TimestampType); ok &&
tsType.TimeZone != "" {
+ var err error
+ tz, err = time.LoadLocation(tsType.TimeZone)
+ if err != nil {
+ return fmt.Errorf("%w: invalid timezone %q: %v",
arrow.ErrInvalid, tsType.TimeZone, err)
+ }
+ }
+
+ fn := func(_ *exec.KernelCtx, ts int64, e *error) int64 {
+ result, err := roundTimestamp(ts, inputType.TimeUnit(), tz,
state)
+ if err != nil {
+ *e = err
+ return 0
+ }
+ return result
+ }
+
+ switch inputType.TimeUnit() {
+ case arrow.Second, arrow.Millisecond, arrow.Microsecond,
arrow.Nanosecond:
+ return ScalarUnaryNotNull(fn)(ctx, batch, out)
+ default:
+ return fmt.Errorf("%w: unsupported time unit",
arrow.ErrNotImplemented)
+ }
+}
+
+type timestampUnitMatcher struct {
+ unit arrow.TimeUnit
+}
+
+func (m *timestampUnitMatcher) Matches(typ arrow.DataType) bool {
+ if ts, ok := typ.(*arrow.TimestampType); ok {
+ return ts.Unit == m.unit
+ }
+ return false
+}
+
+func (m *timestampUnitMatcher) String() string {
+ return "timestamp(unit=" + m.unit.String() + ")"
+}
+
+func (m *timestampUnitMatcher) Equals(other exec.TypeMatcher) bool {
+ if o, ok := other.(*timestampUnitMatcher); ok {
+ return m.unit == o.unit
+ }
+ return false
+}
+
+type dateTypeMatcher struct {
+ dateTypeID arrow.Type
+}
+
+func (m *dateTypeMatcher) Matches(typ arrow.DataType) bool {
+ return typ.ID() == m.dateTypeID
+}
+
+func (m *dateTypeMatcher) String() string {
+ if m.dateTypeID == arrow.DATE32 {
+ return "date32"
+ }
+ return "date64"
+}
+
+func (m *dateTypeMatcher) Equals(other exec.TypeMatcher) bool {
+ if o, ok := other.(*dateTypeMatcher); ok {
+ return m.dateTypeID == o.dateTypeID
+ }
+ return false
+}
Review Comment:
can you move this to `exec/kernel.go` and expose it with a function like the
other matchers we have?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]