This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new b85139204a ARROW-13530: [C++] Implement cumulative sum compute function
b85139204a is described below
commit b85139204a7d6f9e92daf8d6987695ce133a1aa6
Author: JabariBooker <[email protected]>
AuthorDate: Tue May 31 15:23:09 2022 +0200
ARROW-13530: [C++] Implement cumulative sum compute function
Creating new compute function to perform a cumulative sum on a given
array/vector.
Closes #12460 from JabariBooker/ARROW-13530
Lead-authored-by: JabariBooker <[email protected]>
Co-authored-by: Eduardo Ponce <[email protected]>
Co-authored-by: Jabari Booker <[email protected]>
Co-authored-by: Jabari Booker <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/CMakeLists.txt | 1 +
cpp/src/arrow/compute/api_vector.cc | 26 +
cpp/src/arrow/compute/api_vector.h | 27 +
cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 +
.../compute/kernels/base_arithmetic_internal.h | 602 +++++++++++++++++++++
cpp/src/arrow/compute/kernels/codegen_internal.h | 31 ++
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc | 593 +-------------------
.../arrow/compute/kernels/vector_cumulative_ops.cc | 235 ++++++++
.../compute/kernels/vector_cumulative_ops_test.cc | 357 ++++++++++++
cpp/src/arrow/compute/kernels/vector_sort.cc | 8 +-
cpp/src/arrow/compute/registry.cc | 1 +
cpp/src/arrow/compute/registry_internal.h | 1 +
docs/source/cpp/compute.rst | 26 +
docs/source/python/api/compute.rst | 16 +
python/pyarrow/_compute.pyx | 28 +
python/pyarrow/compute.py | 1 +
python/pyarrow/includes/libarrow.pxd | 6 +
python/pyarrow/tests/test_compute.py | 53 ++
18 files changed, 1415 insertions(+), 598 deletions(-)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index ec6cada1cd..fd2f10db2f 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -437,6 +437,7 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_validity.cc
compute/kernels/util_internal.cc
compute/kernels/vector_array_sort.cc
+ compute/kernels/vector_cumulative_ops.cc
compute/kernels/vector_hash.cc
compute/kernels/vector_nested.cc
compute/kernels/vector_replace.cc
diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc
index a5cb61d6b5..e3db022536 100644
--- a/cpp/src/arrow/compute/api_vector.cc
+++ b/cpp/src/arrow/compute/api_vector.cc
@@ -135,6 +135,10 @@ static auto kPartitionNthOptionsType = GetFunctionOptionsType<PartitionNthOption
static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
DataMember("k", &SelectKOptions::k),
DataMember("sort_keys", &SelectKOptions::sort_keys));
+static auto kCumulativeSumOptionsType = GetFunctionOptionsType<CumulativeSumOptions>(
+ DataMember("start", &CumulativeSumOptions::start),
+ DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls),
+ DataMember("check_overflow", &CumulativeSumOptions::check_overflow));  // Reflected fields of CumulativeSumOptions; registered in RegisterVectorOptions().
} // namespace
} // namespace internal
@@ -176,6 +180,18 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
sort_keys(std::move(sort_keys)) {}
constexpr char SelectKOptions::kTypeName[];
+CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls,
+ bool check_overflow)
+ : CumulativeSumOptions(std::make_shared<DoubleScalar>(start), skip_nulls,
+ check_overflow) {}  // Delegates with `start` wrapped in a DoubleScalar.
+CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls,
+ bool check_overflow)
+ : FunctionOptions(internal::kCumulativeSumOptionsType),
+ start(std::move(start)),
+ skip_nulls(skip_nulls),
+ check_overflow(check_overflow) {}  // Primary constructor: stores the options as given.
+constexpr char CumulativeSumOptions::kTypeName[];  // Out-of-line definition of the static constexpr member.
+
namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
@@ -185,6 +201,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeSumOptionsType));
}
} // namespace internal
@@ -325,6 +342,15 @@ Result<std::shared_ptr<Array>> DropNull(const Array& values, ExecContext* ctx) {
return out.make_array();
}
+// ----------------------------------------------------------------------
+// Cumulative functions
+
+Result<Datum> CumulativeSum(const Datum& values, const CumulativeSumOptions& options,
+ ExecContext* ctx) {
+ auto func_name = (options.check_overflow) ? "cumulative_sum_checked" : "cumulative_sum";  // check_overflow selects the checked kernel.
+ return CallFunction(func_name, {Datum(values)}, &options, ctx);
+}
+
// ----------------------------------------------------------------------
// Deprecated functions
diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h
index 9e53cfcf64..b5daddb17b 100644
--- a/cpp/src/arrow/compute/api_vector.h
+++ b/cpp/src/arrow/compute/api_vector.h
@@ -188,6 +188,27 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
NullPlacement null_placement;
};
+/// \brief Options for cumulative sum function
+class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions {
+ public:
+ explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false,
+ bool check_overflow = false);  // `start` is stored as a DoubleScalar.
+ explicit CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls = false,
+ bool check_overflow = false);
+ static constexpr char const kTypeName[] = "CumulativeSumOptions";
+ static CumulativeSumOptions Defaults() { return CumulativeSumOptions(); }
+
+ /// Optional starting value for cumulative operation computation
+ std::shared_ptr<Scalar> start;
+
+ /// If true, nulls in the input are ignored and produce a corresponding null output.
+ /// When false, the first null encountered is propagated through the remaining output.
+ bool skip_nulls = false;
+
+ /// When true, returns an Invalid Status when overflow is detected ("cumulative_sum_checked")
+ bool check_overflow = false;
+};
+
/// @}
/// \brief Filter with a boolean selection filter
@@ -522,6 +543,12 @@ Result<Datum> DictionaryEncode(
const DictionaryEncodeOptions& options =
DictionaryEncodeOptions::Defaults(),
ExecContext* ctx = NULLPTR);
+ARROW_EXPORT
+Result<Datum> CumulativeSum(  // Cumulative sum of `values`; see CumulativeSumOptions for start/null/overflow handling.
+ const Datum& values,
+ const CumulativeSumOptions& options = CumulativeSumOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
// ----------------------------------------------------------------------
// Deprecated functions
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 0a7f619112..780699886d 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -48,6 +48,7 @@ add_arrow_benchmark(scalar_temporal_benchmark PREFIX "arrow-compute")
add_arrow_compute_test(vector_test
SOURCES
+ vector_cumulative_ops_test.cc
vector_hash_test.cc
vector_nested_test.cc
vector_replace_test.cc
diff --git a/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h
new file mode 100644
index 0000000000..1707ed7c13
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h
@@ -0,0 +1,602 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::AddWithOverflow;
+using internal::DivideWithOverflow;
+using internal::MultiplyWithOverflow;
+using internal::NegateWithOverflow;
+using internal::SubtractWithOverflow;
+
+namespace compute {
+namespace internal {
+
+struct Add {  // Unchecked addition: signed integers go through SafeSignedAdd, other types use operator+.
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ return left + right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_unsigned_integer_value<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ return left + right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ return arrow::internal::SafeSignedAdd(left, right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + right;
+ }
+};
+
+struct AddChecked {  // Addition that reports integer overflow as Status::Invalid("overflow").
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left + right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + right;
+ }
+};
+
+template <int64_t multiple>
+struct AddTimeDuration {  // Adds, then reports a result outside [0, multiple) as Status::Invalid.
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ T result =
+ arrow::internal::SafeSignedAdd(static_cast<T>(left), static_cast<T>(right));
+ if (result < 0 || multiple <= result) {
+ *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
+ multiple, ") s");
+ }
+ return result;
+ }
+};
+
+template <int64_t multiple>
+struct AddTimeDurationChecked {  // Like AddTimeDuration, but also reports integer overflow.
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(
+ AddWithOverflow(static_cast<T>(left), static_cast<T>(right), &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ if (result < 0 || multiple <= result) {  // NOTE: also runs after overflow, so this message can replace "overflow".
+ *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
+ multiple, ") s");
+ }
+ return result;
+ }
+};
+
+struct AbsoluteValue {  // Unchecked |x|; negative signed integers go through SafeSignedNegate.
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return std::fabs(arg);
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_unsigned_integer_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_signed_integer_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ return (arg < 0) ? arrow::internal::SafeSignedNegate(arg) : arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return arg.Abs();
+ }
+};
+
+struct AbsoluteValueChecked {  // |x| that reports Status::Invalid("overflow") for the signed minimum.
+ template <typename T, typename Arg>
+ static enable_if_signed_integer_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == std::numeric_limits<Arg>::min()) {
+ *st = Status::Invalid("overflow");
+ return arg;
+ }
+ return std::abs(arg);
+ }
+
+ template <typename T, typename Arg>
+ static enable_if_unsigned_integer_value<Arg, T> Call(KernelContext* ctx, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return std::fabs(arg);
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return arg.Abs();
+ }
+};
+
+struct Subtract {  // Unchecked subtraction; signed integers use SafeSignedSubtract, decimals compute left + (-right).
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left - right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_unsigned_integer_value<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left - right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ return arrow::internal::SafeSignedSubtract(left, right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + (-right);
+ }
+};
+
+struct SubtractChecked {  // Subtraction that reports integer overflow as Status::Invalid("overflow").
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left - right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + (-right);
+ }
+};
+
+struct SubtractDate32 {  // Unchecked difference of two date32 values, multiplied by kSecondsInDay (86400).
+ static constexpr int64_t kSecondsInDay = 86400;
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return arrow::internal::SafeSignedSubtract(left, right) * kSecondsInDay;
+ }
+};
+
+struct SubtractCheckedDate32 {  // Checked date32 difference: reports overflow from the subtraction or the x86400 scaling.
+ static constexpr int64_t kSecondsInDay = 86400;
+
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(result, kSecondsInDay, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+};
+
+template <int64_t multiple>
+struct SubtractTimeDuration {  // Subtracts, then reports a result outside [0, multiple) as Status::Invalid.
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ T result = arrow::internal::SafeSignedSubtract(left, static_cast<T>(right));
+ if (result < 0 || multiple <= result) {
+ *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
+ multiple, ") s");
+ }
+ return result;
+ }
+};
+
+template <int64_t multiple>
+struct SubtractTimeDurationChecked {  // Like SubtractTimeDuration, but also reports integer overflow.
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast<T>(right), &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ if (result < 0 || multiple <= result) {  // NOTE: also runs after overflow, so this message can replace "overflow".
+ *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
+ multiple, ") s");
+ }
+ return result;
+ }
+};
+
+struct Multiply {  // Unchecked multiplication; signed ints multiply via to_unsigned(), 16-bit ints via uint32_t.
+ static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(uint16_t() * uint16_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(uint32_t() * uint32_t()), uint32_t>::value, "");
+ static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value, "");
+ static_assert(std::is_same<decltype(uint64_t() * uint64_t()), uint64_t>::value, "");
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_floating_value<T> Call(KernelContext*, T left, T right,
+ Status*) {
+ return left * right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_t<
+ is_unsigned_integer_value<T>::value && !std::is_same<T, uint16_t>::value, T>
+ Call(KernelContext*, T left, T right, Status*) {
+ return left * right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_t<
+ is_signed_integer_value<T>::value && !std::is_same<T, int16_t>::value, T>
+ Call(KernelContext*, T left, T right, Status*) {
+ return to_unsigned(left) * to_unsigned(right);
+ }
+
+ // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit
+ // integer. However, some inputs may nevertheless overflow (which triggers undefined
+ // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is
+ // well defined.
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_same<T, int16_t, T> Call(KernelContext*, int16_t left,
+ int16_t right, Status*) {
+ return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
+ }
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_same<T, uint16_t, T> Call(KernelContext*, uint16_t left,
+ uint16_t right, Status*) {
+ return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left * right;
+ }
+};
+
+struct MultiplyChecked {  // Multiplication that reports integer overflow as Status::Invalid("overflow").
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left * right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left * right;
+ }
+};
+
+struct Divide {  // Division; integer divide-by-zero reports "divide by zero", other integer overflow yields 0.
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ return left / right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ T result = 0;  // initialized like the other *Checked ops so the error path never returns an indeterminate value
+ if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
+ if (right == 0) {
+ *st = Status::Invalid("divide by zero");
+ } else {
+ result = 0;
+ }
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ if (right == Arg1()) {
+ *st = Status::Invalid("Divide by zero");
+ return T();
+ } else {
+ return left / right;
+ }
+ }
+};
+
+struct DivideChecked {  // Division that reports divide-by-zero and integer overflow as Status::Invalid.
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ T result = 0;  // initialized like the other *Checked ops so the error path never returns an indeterminate value
+ if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
+ if (right == 0) {
+ *st = Status::Invalid("divide by zero");
+ } else {
+ *st = Status::Invalid("overflow");
+ }
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ if (ARROW_PREDICT_FALSE(right == 0)) {
+ *st = Status::Invalid("divide by zero");
+ return 0;
+ }
+ return left / right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext* ctx, Arg0 left, Arg1 right,
+ Status* st) {
+ return Divide::Call<T>(ctx, left, right, st);
+ }
+};
+
+struct Negate {  // Unchecked negation; unsigned integers use ~arg + 1, signed use SafeSignedNegate.
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg arg, Status*) {
+ return -arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_unsigned_integer_value<T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return ~arg + 1;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return arrow::internal::SafeSignedNegate(arg);
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return arg.Negate();
+ }
+};
+
+struct NegateChecked {  // Negation that reports signed overflow; the unsigned overload exists only for instantiation.
+ template <typename T, typename Arg>
+ static enable_if_signed_integer_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(NegateWithOverflow(arg, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg>
+ static enable_if_unsigned_integer_value<Arg, T> Call(KernelContext* ctx, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ DCHECK(false) << "This is included only for the purposes of instantiability from the "
+ "arithmetic kernel generator";
+ return 0;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return -arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return arg.Negate();
+ }
+};
+
+struct Power {  // Unchecked power; integer bases with negative exponents report Status::Invalid.
+ ARROW_NOINLINE
+ static uint64_t IntegerPower(uint64_t base, uint64_t exp) {
+ // right to left O(logn) power
+ uint64_t pow = 1;
+ while (exp) {
+ pow *= (exp & 1) ? base : 1;
+ base *= base;
+ exp >>= 1;
+ }
+ return pow;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer_value<T> Call(KernelContext*, T base, T exp, Status* st) {
+ if (exp < 0) {
+ *st = Status::Invalid("integers to negative integer powers are not allowed");
+ return 0;
+ }
+ return static_cast<T>(IntegerPower(base, exp));
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_value<T> Call(KernelContext*, T base, T exp, Status*) {
+ return std::pow(base, exp);
+ }
+};
+
+struct PowerChecked {  // Power with integer overflow checks; negative integer exponents report Status::Invalid.
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer_value<T> Call(KernelContext*, Arg0 base, Arg1 exp,
+ Status* st) {
+ if (exp < 0) {
+ *st = Status::Invalid("integers to negative integer powers are not allowed");
+ return 0;
+ } else if (exp == 0) {
+ return 1;
+ }
+ // left to right O(logn) power with overflow checks
+ bool overflow = false;
+ uint64_t bitmask =
+ 1ULL << (63 - bit_util::CountLeadingZeros(static_cast<uint64_t>(exp)));
+ T pow = 1;
+ while (bitmask) {
+ overflow |= MultiplyWithOverflow(pow, pow, &pow);
+ if (exp & bitmask) {
+ overflow |= MultiplyWithOverflow(pow, base, &pow);
+ }
+ bitmask >>= 1;
+ }
+ if (overflow) {
+ *st = Status::Invalid("overflow");
+ }
+ return pow;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_value<T> Call(KernelContext*, Arg0 base, Arg1 exp, Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return std::pow(base, exp);
+ }
+};
+
+struct SquareRoot {
+ template <typename T, typename Arg>
+ static enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg < 0.0) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::sqrt(arg);
+ }
+};
+
+struct SquareRootChecked {  // sqrt that reports negative input as Status::Invalid.
+ template <typename T, typename Arg>
+ static enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg < 0.0) {
+ *st = Status::Invalid("square root of negative number");
+ return arg;
+ }
+ return std::sqrt(arg);
+ }
+};
+
+struct Sign {  // Returns -1, 0 or 1 (decimals via Sign()); floating-point NaN is passed through.
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return std::isnan(arg) ? arg : ((arg == 0) ? 0 : (std::signbit(arg) ? -1 : 1));
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_unsigned_integer_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return (arg > 0) ? 1 : 0;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_signed_integer_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return (arg > 0) ? 1 : ((arg == 0) ? 0 : -1);
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return (arg == 0) ? 0 : arg.Sign();
+ }
+};
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index fa50427bc3..6d31c1fe24 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1134,6 +1134,37 @@ ArrayKernelExec GeneratePhysicalInteger(detail::GetTypeId get_id) {
}
}
+template <template <typename...> class KernelGenerator, typename Op, typename... Args>
+ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {  // Returns KernelGenerator<T, T, Op, Args...>::Exec for the numeric type T of get_id.
+ switch (get_id.id) {
+ case Type::INT8:
+ return KernelGenerator<Int8Type, Int8Type, Op, Args...>::Exec;
+ case Type::UINT8:
+ return KernelGenerator<UInt8Type, UInt8Type, Op, Args...>::Exec;
+ case Type::INT16:
+ return KernelGenerator<Int16Type, Int16Type, Op, Args...>::Exec;
+ case Type::UINT16:
+ return KernelGenerator<UInt16Type, UInt16Type, Op, Args...>::Exec;
+ case Type::INT32:
+ return KernelGenerator<Int32Type, Int32Type, Op, Args...>::Exec;
+ case Type::UINT32:
+ return KernelGenerator<UInt32Type, UInt32Type, Op, Args...>::Exec;
+ case Type::DURATION:  // DURATION and TIMESTAMP share the Int64Type kernel.
+ case Type::INT64:
+ case Type::TIMESTAMP:
+ return KernelGenerator<Int64Type, Int64Type, Op, Args...>::Exec;
+ case Type::UINT64:
+ return KernelGenerator<UInt64Type, UInt64Type, Op, Args...>::Exec;
+ case Type::FLOAT:
+ return KernelGenerator<FloatType, FloatType, Op, Args...>::Exec;
+ case Type::DOUBLE:
+ return KernelGenerator<DoubleType, DoubleType, Op, Args...>::Exec;
+ default:
+ DCHECK(false);  // Unsupported type id; release builds fall back to ExecFail.
+ return ExecFail;
+ }
+}
+
template <template <typename... Args> class Generator, typename... Args>
ArrayKernelExec GeneratePhysicalNumeric(detail::GetTypeId get_id) {
switch (get_id.id) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index 4365ad4e76..0742fb32c5 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -25,6 +25,7 @@
#include "arrow/compare.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/base_arithmetic_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type.h"
@@ -81,566 +82,6 @@ bool IsPositive(const Scalar& scalar) {
// N.B. take care not to conflict with type_traits.h as that can cause surprises in a
// unity build
-struct AbsoluteValue {
- template <typename T, typename Arg>
- static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status*) {
- return std::fabs(arg);
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_unsigned_integer_value<Arg, T>
Call(KernelContext*, Arg arg,
- Status*) {
- return arg;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_signed_integer_value<Arg, T> Call(KernelContext*,
Arg arg,
- Status* st) {
- return (arg < 0) ? arrow::internal::SafeSignedNegate(arg) : arg;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status*) {
- return arg.Abs();
- }
-};
-
-struct AbsoluteValueChecked {
- template <typename T, typename Arg>
- static enable_if_signed_integer_value<Arg, T> Call(KernelContext*, Arg arg,
- Status* st) {
- static_assert(std::is_same<T, Arg>::value, "");
- if (arg == std::numeric_limits<Arg>::min()) {
- *st = Status::Invalid("overflow");
- return arg;
- }
- return std::abs(arg);
- }
-
- template <typename T, typename Arg>
- static enable_if_unsigned_integer_value<Arg, T> Call(KernelContext* ctx, Arg
arg,
- Status* st) {
- static_assert(std::is_same<T, Arg>::value, "");
- return arg;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status* st) {
- static_assert(std::is_same<T, Arg>::value, "");
- return std::fabs(arg);
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status*) {
- return arg.Abs();
- }
-};
-
-struct Add {
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left,
Arg1 right,
- Status*) {
- return left + right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_unsigned_integer_value<T> Call(KernelContext*,
Arg0 left,
- Arg1 right,
Status*) {
- return left + right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg0
left,
- Arg1 right, Status*)
{
- return arrow::internal::SafeSignedAdd(left, right);
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1
right, Status*) {
- return left + right;
- }
-};
-
-struct AddChecked {
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
- Status* st) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- T result = 0;
- if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) {
- *st = Status::Invalid("overflow");
- }
- return result;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1
right,
- Status*) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- return left + right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1
right, Status*) {
- return left + right;
- }
-};
-
-struct Subtract {
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left,
Arg1 right,
- Status*) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- return left - right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_unsigned_integer_value<T> Call(KernelContext*,
Arg0 left,
- Arg1 right,
Status*) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- return left - right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg0
left,
- Arg1 right, Status*)
{
- return arrow::internal::SafeSignedSubtract(left, right);
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1
right, Status*) {
- return left + (-right);
- }
-};
-
-struct SubtractChecked {
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
- Status* st) {
- T result = 0;
- if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
- *st = Status::Invalid("overflow");
- }
- return result;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1
right,
- Status*) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- return left - right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1
right, Status*) {
- return left + (-right);
- }
-};
-
-struct SubtractDate32 {
- static constexpr int64_t kSecondsInDay = 86400;
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
- return arrow::internal::SafeSignedSubtract(left, right) * kSecondsInDay;
- }
-};
-
-struct SubtractCheckedDate32 {
- static constexpr int64_t kSecondsInDay = 86400;
-
- template <typename T, typename Arg0, typename Arg1>
- static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
- T result = 0;
- if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
- *st = Status::Invalid("overflow");
- }
- if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(result, kSecondsInDay,
&result))) {
- *st = Status::Invalid("overflow");
- }
- return result;
- }
-};
-
-template <int64_t multiple>
-struct AddTimeDuration {
- template <typename T, typename Arg0, typename Arg1>
- static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
- T result =
- arrow::internal::SafeSignedAdd(static_cast<T>(left),
static_cast<T>(right));
- if (result < 0 || multiple <= result) {
- *st = Status::Invalid(result, " is not within the acceptable range of ",
"[0, ",
- multiple, ") s");
- }
- return result;
- }
-};
-
-template <int64_t multiple>
-struct AddTimeDurationChecked {
- template <typename T, typename Arg0, typename Arg1>
- static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
- T result = 0;
- if (ARROW_PREDICT_FALSE(
- AddWithOverflow(static_cast<T>(left), static_cast<T>(right),
&result))) {
- *st = Status::Invalid("overflow");
- }
- if (result < 0 || multiple <= result) {
- *st = Status::Invalid(result, " is not within the acceptable range of ",
"[0, ",
- multiple, ") s");
- }
- return result;
- }
-};
-
-template <int64_t multiple>
-struct SubtractTimeDuration {
- template <typename T, typename Arg0, typename Arg1>
- static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
- T result = arrow::internal::SafeSignedSubtract(left,
static_cast<T>(right));
- if (result < 0 || multiple <= result) {
- *st = Status::Invalid(result, " is not within the acceptable range of ",
"[0, ",
- multiple, ") s");
- }
- return result;
- }
-};
-
-template <int64_t multiple>
-struct SubtractTimeDurationChecked {
- template <typename T, typename Arg0, typename Arg1>
- static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
- T result = 0;
- if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast<T>(right),
&result))) {
- *st = Status::Invalid("overflow");
- }
- if (result < 0 || multiple <= result) {
- *st = Status::Invalid(result, " is not within the acceptable range of ",
"[0, ",
- multiple, ") s");
- }
- return result;
- }
-};
-
-struct Multiply {
- static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value,
"");
- static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value,
"");
- static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value,
"");
- static_assert(std::is_same<decltype(uint16_t() * uint16_t()),
int32_t>::value, "");
- static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value,
"");
- static_assert(std::is_same<decltype(uint32_t() * uint32_t()),
uint32_t>::value, "");
- static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value,
"");
- static_assert(std::is_same<decltype(uint64_t() * uint64_t()),
uint64_t>::value, "");
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_floating_value<T> Call(KernelContext*, T left, T
right,
- Status*) {
- return left * right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_t<
- is_unsigned_integer_value<T>::value && !std::is_same<T,
uint16_t>::value, T>
- Call(KernelContext*, T left, T right, Status*) {
- return left * right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_t<
- is_signed_integer_value<T>::value && !std::is_same<T, int16_t>::value, T>
- Call(KernelContext*, T left, T right, Status*) {
- return to_unsigned(left) * to_unsigned(right);
- }
-
- // Multiplication of 16 bit integer types implicitly promotes to signed 32
bit
- // integer. However, some inputs may nevertheless overflow (which triggers
undefined
- // behaviour). Therefore we first cast to 32 bit unsigned integers where
overflow is
- // well defined.
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_same<T, int16_t, T> Call(KernelContext*, int16_t
left,
- int16_t right, Status*) {
- return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
- }
- template <typename T, typename Arg0, typename Arg1>
- static constexpr enable_if_same<T, uint16_t, T> Call(KernelContext*,
uint16_t left,
- uint16_t right,
Status*) {
- return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1
right, Status*) {
- return left * right;
- }
-};
-
-struct MultiplyChecked {
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
- Status* st) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- T result = 0;
- if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) {
- *st = Status::Invalid("overflow");
- }
- return result;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1
right,
- Status*) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- return left * right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1
right, Status*) {
- return left * right;
- }
-};
-
-struct Divide {
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1
right,
- Status*) {
- return left / right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
- Status* st) {
- T result;
- if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
- if (right == 0) {
- *st = Status::Invalid("divide by zero");
- } else {
- result = 0;
- }
- }
- return result;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
- Status* st) {
- if (right == Arg1()) {
- *st = Status::Invalid("Divide by zero");
- return T();
- } else {
- return left / right;
- }
- }
-};
-
-struct DivideChecked {
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
- Status* st) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- T result;
- if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
- if (right == 0) {
- *st = Status::Invalid("divide by zero");
- } else {
- *st = Status::Invalid("overflow");
- }
- }
- return result;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1
right,
- Status* st) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- if (ARROW_PREDICT_FALSE(right == 0)) {
- *st = Status::Invalid("divide by zero");
- return 0;
- }
- return left / right;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_decimal_value<T> Call(KernelContext* ctx, Arg0 left, Arg1
right,
- Status* st) {
- return Divide::Call<T>(ctx, left, right, st);
- }
-};
-
-struct Negate {
- template <typename T, typename Arg>
- static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg arg,
Status*) {
- return -arg;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_unsigned_integer_value<T> Call(KernelContext*,
Arg arg,
- Status*) {
- return ~arg + 1;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg
arg,
- Status*) {
- return arrow::internal::SafeSignedNegate(arg);
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status*) {
- return arg.Negate();
- }
-};
-
-struct NegateChecked {
- template <typename T, typename Arg>
- static enable_if_signed_integer_value<Arg, T> Call(KernelContext*, Arg arg,
- Status* st) {
- static_assert(std::is_same<T, Arg>::value, "");
- T result = 0;
- if (ARROW_PREDICT_FALSE(NegateWithOverflow(arg, &result))) {
- *st = Status::Invalid("overflow");
- }
- return result;
- }
-
- template <typename T, typename Arg>
- static enable_if_unsigned_integer_value<Arg, T> Call(KernelContext* ctx, Arg
arg,
- Status* st) {
- static_assert(std::is_same<T, Arg>::value, "");
- DCHECK(false) << "This is included only for the purposes of
instantiability from the "
- "arithmetic kernel generator";
- return 0;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status* st) {
- static_assert(std::is_same<T, Arg>::value, "");
- return -arg;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status*) {
- return arg.Negate();
- }
-};
-
-struct Power {
- ARROW_NOINLINE
- static uint64_t IntegerPower(uint64_t base, uint64_t exp) {
- // right to left O(logn) power
- uint64_t pow = 1;
- while (exp) {
- pow *= (exp & 1) ? base : 1;
- base *= base;
- exp >>= 1;
- }
- return pow;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_integer_value<T> Call(KernelContext*, T base, T exp,
Status* st) {
- if (exp < 0) {
- *st = Status::Invalid("integers to negative integer powers are not
allowed");
- return 0;
- }
- return static_cast<T>(IntegerPower(base, exp));
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_floating_value<T> Call(KernelContext*, T base, T exp,
Status*) {
- return std::pow(base, exp);
- }
-};
-
-struct PowerChecked {
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_integer_value<T> Call(KernelContext*, Arg0 base, Arg1 exp,
- Status* st) {
- if (exp < 0) {
- *st = Status::Invalid("integers to negative integer powers are not
allowed");
- return 0;
- } else if (exp == 0) {
- return 1;
- }
- // left to right O(logn) power with overflow checks
- bool overflow = false;
- uint64_t bitmask =
- 1ULL << (63 - bit_util::CountLeadingZeros(static_cast<uint64_t>(exp)));
- T pow = 1;
- while (bitmask) {
- overflow |= MultiplyWithOverflow(pow, pow, &pow);
- if (exp & bitmask) {
- overflow |= MultiplyWithOverflow(pow, base, &pow);
- }
- bitmask >>= 1;
- }
- if (overflow) {
- *st = Status::Invalid("overflow");
- }
- return pow;
- }
-
- template <typename T, typename Arg0, typename Arg1>
- static enable_if_floating_value<T> Call(KernelContext*, Arg0 base, Arg1 exp,
Status*) {
- static_assert(std::is_same<T, Arg0>::value && std::is_same<T,
Arg1>::value, "");
- return std::pow(base, exp);
- }
-};
-
-struct SquareRoot {
- template <typename T, typename Arg>
- static enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
Status*) {
- static_assert(std::is_same<T, Arg>::value, "");
- if (arg < 0.0) {
- return std::numeric_limits<T>::quiet_NaN();
- }
- return std::sqrt(arg);
- }
-};
-
-struct SquareRootChecked {
- template <typename T, typename Arg>
- static enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
Status* st) {
- static_assert(std::is_same<T, Arg>::value, "");
- if (arg < 0.0) {
- *st = Status::Invalid("square root of negative number");
- return arg;
- }
- return std::sqrt(arg);
- }
-};
-
-struct Sign {
- template <typename T, typename Arg>
- static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status*) {
- return std::isnan(arg) ? arg : ((arg == 0) ? 0 : (std::signbit(arg) ? -1 :
1));
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_unsigned_integer_value<Arg, T>
Call(KernelContext*, Arg arg,
- Status*) {
- return (arg > 0) ? 1 : 0;
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_signed_integer_value<Arg, T> Call(KernelContext*,
Arg arg,
- Status*) {
- return (arg > 0) ? 1 : ((arg == 0) ? 0 : -1);
- }
-
- template <typename T, typename Arg>
- static constexpr enable_if_decimal_value<Arg, T> Call(KernelContext*, Arg
arg,
- Status*) {
- return (arg == 0) ? 0 : arg.Sign();
- }
-};
-
// Bitwise operations
struct BitWiseNot {
@@ -1561,38 +1002,6 @@ struct Trunc {
}
};
-// Generate a kernel given an arithmetic functor
-template <template <typename... Args> class KernelGenerator, typename Op>
-ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
- switch (get_id.id) {
- case Type::INT8:
- return KernelGenerator<Int8Type, Int8Type, Op>::Exec;
- case Type::UINT8:
- return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
- case Type::INT16:
- return KernelGenerator<Int16Type, Int16Type, Op>::Exec;
- case Type::UINT16:
- return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
- case Type::INT32:
- return KernelGenerator<Int32Type, Int32Type, Op>::Exec;
- case Type::UINT32:
- return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
- case Type::INT64:
- case Type::DURATION:
- case Type::TIMESTAMP:
- return KernelGenerator<Int64Type, Int64Type, Op>::Exec;
- case Type::UINT64:
- return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
- case Type::FLOAT:
- return KernelGenerator<FloatType, FloatType, Op>::Exec;
- case Type::DOUBLE:
- return KernelGenerator<DoubleType, DoubleType, Op>::Exec;
- default:
- DCHECK(false);
- return ExecFail;
- }
-}
-
// Generate a kernel given a bitwise arithmetic functor. Assumes the
// functor treats all integer types of equal width identically
template <template <typename... Args> class KernelGenerator, typename Op>
diff --git a/cpp/src/arrow/compute/kernels/vector_cumulative_ops.cc b/cpp/src/arrow/compute/kernels/vector_cumulative_ops.cc
new file mode 100644
index 0000000000..c0eb40964d
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_cumulative_ops.cc
@@ -0,0 +1,235 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/base_arithmetic_internal.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/visit_type_inline.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+namespace {
+// Kernel state for cumulative functions. Init rejects a missing/null `start`
+// option and, when its type differs from the kernel input type, casts it (safe
+// cast) so the exec function can unbox it with the input's value type.
+template <typename OptionsType>
+struct CumulativeOptionsWrapper : public OptionsWrapper<OptionsType> {
+  using State = CumulativeOptionsWrapper<OptionsType>;
+
+  explicit CumulativeOptionsWrapper(OptionsType options)
+      : OptionsWrapper<OptionsType>(std::move(options)) {}
+
+  static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+                                                   const KernelInitArgs& args) {
+    auto options = checked_cast<const OptionsType*>(args.options);
+    if (!options) {
+      return Status::Invalid(
+          "Attempted to initialize KernelState from null FunctionOptions");
+    }
+
+    const auto& start = options->start;
+    if (!start || !start->is_valid) {
+      return Status::Invalid("Cumulative `start` option must be non-null and valid");
+    }
+
+    // Ensure `start` option matches input type
+    if (!start->type->Equals(args.inputs[0].type)) {
+      ARROW_ASSIGN_OR_RAISE(auto casted_start,
+                            Cast(Datum(start), args.inputs[0].type, CastOptions::Safe(),
+                                 ctx->exec_context()));
+      auto new_options = OptionsType(casted_start.scalar(), options->skip_nulls);
+      // The rebuilt options are local, so hand them over instead of copying.
+      return ::arrow::internal::make_unique<State>(std::move(new_options));
+    }
+    return ::arrow::internal::make_unique<State>(*options);
+  }
+};
+
+// The driver kernel for all cumulative compute functions. Op is a compute kernel
+// representing any binary associative operation (add, product, min, max, etc.) and
+// OptionsType the options type corresponding to Op. ArgType and OutType are the input
+// and output types, which will normally be the same (e.g. the cumulative sum of an array
+// of Int64Type will result in an array of Int64Type).
+template <typename OutType, typename ArgType, typename Op, typename OptionsType>
+struct CumulativeGeneric {
+  using OutValue = typename GetOutputType<OutType>::T;
+  using ArgValue = typename GetViewType<ArgType>::T;
+
+  KernelContext* ctx;
+  // Running result, seeded from the `start` option and carried across chunks.
+  ArgValue accumulator;
+  bool skip_nulls;
+  // Set once a null is seen while !skip_nulls; every later output is null.
+  bool encountered_null = false;
+  Datum values;
+  NumericBuilder<OutType>* builder;
+
+  // Fold Op over `values` (scalar, array or chunked array) into `*out_arr`.
+  Status Cumulate(std::shared_ptr<ArrayData>* out_arr) {
+    switch (values.kind()) {
+      case Datum::SCALAR: {
+        const auto& in_scalar = values.scalar();
+        if (!skip_nulls && !in_scalar->is_valid) {
+          RETURN_NOT_OK(builder->AppendNull());
+          break;
+        }
+
+        // NOTE(review): with skip_nulls a null scalar emits the accumulator (the
+        // `start` value), whereas Call() emits null for a null array element when
+        // skip_nulls is set. Confirm which behavior is intended.
+        if (skip_nulls && !in_scalar->is_valid) {
+          RETURN_NOT_OK(builder->Append(accumulator));
+          break;
+        }
+
+        Status st;
+        auto in_value = UnboxScalar<ArgType>::Unbox(*in_scalar);
+        auto result = Op::template Call<OutValue, ArgValue, ArgValue>(ctx, accumulator,
+                                                                     in_value, &st);
+        RETURN_NOT_OK(st);
+        RETURN_NOT_OK(builder->Append(result));
+        break;
+      }
+      case Datum::ARRAY: {
+        const auto& arr_input = values.array();
+        RETURN_NOT_OK(builder->Reserve(arr_input->length));
+        RETURN_NOT_OK(Call(*arr_input));
+        break;
+      }
+      case Datum::CHUNKED_ARRAY: {
+        const auto& chunked_input = values.chunked_array();
+        // Reserve for the whole output: Call() appends with UnsafeAppend*.
+        RETURN_NOT_OK(builder->Reserve(chunked_input->length()));
+
+        for (const auto& chunk : chunked_input->chunks()) {
+          RETURN_NOT_OK(Call(*chunk->data()));
+        }
+        break;
+      }
+      default:
+        return Status::NotImplemented(
+            "Unsupported input type for function 'cumulative_<operator>': ",
+            values.ToString());
+    }
+
+    return builder->FinishInternal(out_arr);
+  }
+
+  // Append the cumulative results for one array (or chunk) to `builder`, which
+  // must already have room for input.length more elements. Returns the status
+  // reported by Op (e.g. an overflow error from a checked operation).
+  Status Call(const ArrayData& input) {
+    Status st = Status::OK();
+
+    if (skip_nulls || (input.GetNullCount() == 0 && !encountered_null)) {
+      // Either nulls are skipped (each null input gives a null output and the
+      // running result carries on past it), or there is no null to propagate.
+      VisitArrayValuesInline<ArgType>(
+          input,
+          [&](ArgValue v) {
+            // Fold left to right: accumulator = accumulator Op v.
+            accumulator = Op::template Call<OutValue, ArgValue, ArgValue>(
+                ctx, accumulator, v, &st);
+            builder->UnsafeAppend(accumulator);
+          },
+          [&]() { builder->UnsafeAppendNull(); });
+    } else {
+      // Propagate the first null: emit results up to it, then nulls for the rest.
+      int64_t nulls_start_idx = 0;
+      VisitArrayValuesInline<ArgType>(
+          input,
+          [&](ArgValue v) {
+            if (!encountered_null) {
+              accumulator = Op::template Call<OutValue, ArgValue, ArgValue>(
+                  ctx, accumulator, v, &st);
+              builder->UnsafeAppend(accumulator);
+              ++nulls_start_idx;
+            }
+          },
+          [&]() { encountered_null = true; });
+
+      RETURN_NOT_OK(builder->AppendNulls(input.length - nulls_start_idx));
+    }
+
+    return st;
+  }
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const auto& options = CumulativeOptionsWrapper<OptionsType>::Get(ctx);
+
+    // CumulativeOptionsWrapper::Init has already cast `start` to the input type.
+    auto start = UnboxScalar<OutType>::Unbox(*(options.start));
+    NumericBuilder<OutType> builder(ctx->memory_pool());
+
+    CumulativeGeneric self;
+    self.ctx = ctx;
+    self.accumulator = start;
+    self.skip_nulls = options.skip_nulls;
+    self.values = batch[0];
+    self.builder = &builder;
+
+    std::shared_ptr<ArrayData> out_arr;
+    RETURN_NOT_OK(self.Cumulate(&out_arr));
+    out->value = std::move(out_arr);
+    return Status::OK();
+  }
+};
+
+// Documentation for "cumulative_sum" (registered with the unchecked Add op).
+const FunctionDoc cumulative_sum_doc{
+ "Compute the cumulative sum over a numeric input",
+ ("`values` must be numeric. Return an array/chunked array which is the\n"
+ "cumulative sum computed over `values`. Results will wrap around on\n"
+ "integer overflow. Use function \"cumulative_sum_checked\" if you want\n"
+ "overflow to return an error."),
+ {"values"},
+ "CumulativeSumOptions"};
+
+// Documentation for "cumulative_sum_checked" (registered with AddChecked).
+const FunctionDoc cumulative_sum_checked_doc{
+ "Compute the cumulative sum over a numeric input",
+ ("`values` must be numeric. Return an array/chunked array which is the\n"
+ "cumulative sum computed over `values`. This function returns an error\n"
+ "on overflow. For a variant that doesn't fail on overflow, use\n"
+ "function \"cumulative_sum\"."),
+ {"values"},
+ "CumulativeSumOptions"};
+} // namespace
+
+// Register vector function `func_name` with one kernel per numeric type. Each
+// kernel's output type equals its input type, folds Op over the input via
+// CumulativeGeneric, and initializes its state with CumulativeOptionsWrapper.
+template <typename Op, typename OptionsType>
+void MakeVectorCumulativeFunction(FunctionRegistry* registry,
+                                  const std::string& func_name, const FunctionDoc& doc) {
+  static const OptionsType kDefaultOptions = OptionsType::Defaults();
+  auto func =
+      std::make_shared<VectorFunction>(func_name, Arity::Unary(), doc, &kDefaultOptions);
+
+  for (const auto& ty : NumericTypes()) {
+    VectorKernel kernel;
+    // The running result spans chunks, so the kernel must see all chunks at once.
+    kernel.can_execute_chunkwise = false;
+    kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
+    kernel.mem_allocation = MemAllocation::type::NO_PREALLOCATE;
+    kernel.signature = KernelSignature::Make({InputType(ty)}, OutputType(ty));
+    kernel.exec = ArithmeticExecFromOp<CumulativeGeneric, Op, OptionsType>(ty);
+    kernel.init = CumulativeOptionsWrapper<OptionsType>::Init;
+    DCHECK_OK(func->AddKernel(std::move(kernel)));
+  }
+
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+// Register "cumulative_sum" (wraps on integer overflow) and
+// "cumulative_sum_checked" (errors on integer overflow).
+void RegisterVectorCumulativeSum(FunctionRegistry* registry) {
+  MakeVectorCumulativeFunction<Add, CumulativeSumOptions>(registry, "cumulative_sum",
+                                                          cumulative_sum_doc);
+  MakeVectorCumulativeFunction<AddChecked, CumulativeSumOptions>(
+      registry, "cumulative_sum_checked", cumulative_sum_checked_doc);
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc b/cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc
new file mode 100644
index 0000000000..3485ffffb4
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc
@@ -0,0 +1,357 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/chunked_array.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+
+namespace arrow {
+namespace compute {
+
+TEST(TestCumulativeSum, Empty) {
+  // Empty inputs, plain or chunked, produce equally empty outputs.
+  CumulativeSumOptions defaults;
+  for (auto ty : NumericTypes()) {
+    auto empty_array = ArrayFromJSON(ty, "[]");
+    auto empty_chunked = ChunkedArrayFromJSON(ty, {"[]"});
+    for (const char* func : {"cumulative_sum", "cumulative_sum_checked"}) {
+      CheckVectorUnary(func, empty_array, empty_array, &defaults);
+      CheckVectorUnary(func, empty_chunked, empty_chunked, &defaults);
+    }
+  }
+}
+
+TEST(TestCumulativeSum, AllNulls) {
+  // An all-null input yields an all-null output, regardless of chunking.
+  CumulativeSumOptions defaults;
+  for (auto ty : NumericTypes()) {
+    auto all_null = ArrayFromJSON(ty, "[null, null, null]");
+    auto all_null_single_chunk = ChunkedArrayFromJSON(ty, {"[null, null, null]"});
+    auto all_null_three_chunks =
+        ChunkedArrayFromJSON(ty, {"[null]", "[null]", "[null]"});
+    for (const char* func : {"cumulative_sum", "cumulative_sum_checked"}) {
+      CheckVectorUnary(func, all_null, all_null, &defaults);
+      CheckVectorUnary(func, all_null_single_chunk, all_null_single_chunk, &defaults);
+      CheckVectorUnary(func, all_null_three_chunks, all_null_single_chunk, &defaults);
+    }
+  }
+}
+
+TEST(TestCumulativeSum, ScalarInput) {
+  CumulativeSumOptions no_start_no_skip;
+  CumulativeSumOptions no_start_do_skip(0, true);
+  CumulativeSumOptions has_start_no_skip(10);
+  CumulativeSumOptions has_start_do_skip(10, true);
+
+  for (auto ty : NumericTypes()) {
+    for (const char* func : {"cumulative_sum", "cumulative_sum_checked"}) {
+      // A scalar input produces a one-element array offset by `start`.
+      CheckVectorUnary(func, ScalarFromJSON(ty, "10"), ArrayFromJSON(ty, "[10]"),
+                       &no_start_no_skip);
+      CheckVectorUnary(func, ScalarFromJSON(ty, "10"), ArrayFromJSON(ty, "[20]"),
+                       &has_start_no_skip);
+
+      // Without skip_nulls, a null scalar stays null.
+      CheckVectorUnary(func, ScalarFromJSON(ty, "null"), ArrayFromJSON(ty, "[null]"),
+                       &no_start_no_skip);
+      CheckVectorUnary(func, ScalarFromJSON(ty, "null"), ArrayFromJSON(ty, "[null]"),
+                       &has_start_no_skip);
+
+      // With skip_nulls, a null scalar yields the start value.
+      CheckVectorUnary(func, ScalarFromJSON(ty, "null"), ArrayFromJSON(ty, "[0]"),
+                       &no_start_do_skip);
+      CheckVectorUnary(func, ScalarFromJSON(ty, "null"), ArrayFromJSON(ty, "[10]"),
+                       &has_start_do_skip);
+    }
+  }
+}
+
+using testing::HasSubstr;
+
+/// Check overflow past the type's maximum value using a start of 1:
+/// "cumulative_sum_checked" must fail on both scalar and array input, and
+/// "cumulative_sum" must wrap around to the minimum value.
+/// Despite its name, this is also used for signed integer types.
+template <typename ArrowType>
+void CheckCumulativeSumUnsignedOverflow() {
+  using CType = typename TypeTraits<ArrowType>::CType;
+  using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+  using BuilderType = typename TypeTraits<ArrowType>::BuilderType;
+
+  const auto max_value = std::numeric_limits<CType>::max();
+  const auto min_value = std::numeric_limits<CType>::lowest();
+  CumulativeSumOptions pos_overflow(1);
+
+  // Build the single-element arrays [max] and [min].
+  std::shared_ptr<Array> max_arr;
+  std::shared_ptr<Array> min_arr;
+  BuilderType builder;
+  ASSERT_OK(builder.Append(max_value));
+  ASSERT_OK(builder.Finish(&max_arr));
+  builder.Reset();
+  ASSERT_OK(builder.Append(min_value));
+  ASSERT_OK(builder.Finish(&min_arr));
+
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, HasSubstr("overflow"),
+      CallFunction("cumulative_sum_checked", {ScalarType(max_value)}, &pos_overflow));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, HasSubstr("overflow"),
+      CallFunction("cumulative_sum_checked", {max_arr}, &pos_overflow));
+  CheckVectorUnary("cumulative_sum", ScalarType(max_value), min_arr, &pos_overflow);
+}
+
+/// Check overflow in both directions for a signed integer type. First, run the
+/// positive-overflow check shared with unsigned types. Then, with a start of -1
+/// applied to the minimum value, "cumulative_sum_checked" must fail and
+/// "cumulative_sum" must wrap around to the maximum value.
+template <typename ArrowType>
+void CheckCumulativeSumSignedOverflow() {
+  using CType = typename TypeTraits<ArrowType>::CType;
+  using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+  using BuilderType = typename TypeTraits<ArrowType>::BuilderType;
+
+  // Positive overflow past the maximum value.
+  CheckCumulativeSumUnsignedOverflow<ArrowType>();
+
+  const auto max_value = std::numeric_limits<CType>::max();
+  const auto min_value = std::numeric_limits<CType>::lowest();
+  CumulativeSumOptions neg_overflow(-1);
+
+  // Build the single-element arrays [max] and [min].
+  std::shared_ptr<Array> max_arr;
+  std::shared_ptr<Array> min_arr;
+  BuilderType builder;
+  ASSERT_OK(builder.Append(max_value));
+  ASSERT_OK(builder.Finish(&max_arr));
+  builder.Reset();
+  ASSERT_OK(builder.Append(min_value));
+  ASSERT_OK(builder.Finish(&min_arr));
+
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, HasSubstr("overflow"),
+      CallFunction("cumulative_sum_checked", {ScalarType(min_value)}, &neg_overflow));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, HasSubstr("overflow"),
+      CallFunction("cumulative_sum_checked", {min_arr}, &neg_overflow));
+  CheckVectorUnary("cumulative_sum", ScalarType(min_value), max_arr, &neg_overflow);
+}
+
+TEST(TestCumulativeSum, IntegerOverflow) {
+  // Every integer width is checked for overflow past its maximum. The signed
+  // check additionally covers overflow below the minimum.
+  CheckCumulativeSumUnsignedOverflow<UInt8Type>();
+  CheckCumulativeSumSignedOverflow<Int8Type>();
+  CheckCumulativeSumUnsignedOverflow<UInt16Type>();
+  CheckCumulativeSumSignedOverflow<Int16Type>();
+  CheckCumulativeSumUnsignedOverflow<UInt32Type>();
+  CheckCumulativeSumSignedOverflow<Int32Type>();
+  CheckCumulativeSumUnsignedOverflow<UInt64Type>();
+  CheckCumulativeSumSignedOverflow<Int64Type>();
+}
+
+TEST(TestCumulativeSum, NoStartNoSkip) {
+  // Default options: the running sum starts at 0 and, once a null is seen,
+  // every later output is null as well.
+  CumulativeSumOptions defaults;
+  for (auto ty : NumericTypes()) {
+    for (const char* func : {"cumulative_sum", "cumulative_sum_checked"}) {
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, 3, 4, 5, 6]"),
+                       ArrayFromJSON(ty, "[1, 3, 6, 10, 15, 21]"), &defaults);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[1, 3, null, null, null, null]"), &defaults);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[null, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[null, null, null, null, null, null]"),
+                       &defaults);
+
+      // The running sum carries across chunk boundaries.
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, 3]", "[4, 5, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[1, 3, 6, 10, 15, 21]"}), &defaults);
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[1, 3, null, null, null, null]"}),
+                       &defaults);
+      CheckVectorUnary(func,
+                       ChunkedArrayFromJSON(ty, {"[null, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[null, null, null, null, null, null]"}),
+                       &defaults);
+    }
+  }
+}
+
+TEST(TestCumulativeSum, NoStartDoSkip) {
+  // Start at 0 and skip nulls: each null input maps to a null output, and the
+  // running sum continues with the next valid value.
+  CumulativeSumOptions skipping(0, true);
+  for (auto ty : NumericTypes()) {
+    for (const char* func : {"cumulative_sum", "cumulative_sum_checked"}) {
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, 3, 4, 5, 6]"),
+                       ArrayFromJSON(ty, "[1, 3, 6, 10, 15, 21]"), &skipping);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[1, 3, null, 7, null, 13]"), &skipping);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[null, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[null, 2, null, 6, null, 12]"), &skipping);
+
+      // The running sum carries across chunk boundaries.
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, 3]", "[4, 5, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[1, 3, 6, 10, 15, 21]"}), &skipping);
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[1, 3, null, 7, null, 13]"}),
+                       &skipping);
+      CheckVectorUnary(func,
+                       ChunkedArrayFromJSON(ty, {"[null, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[null, 2, null, 6, null, 12]"}),
+                       &skipping);
+    }
+  }
+}
+
+TEST(TestCumulativeSum, HasStartNoSkip) {
+  // Start at 10 and propagate the first null to every later output.
+  CumulativeSumOptions start_at_ten(10);
+  for (auto ty : NumericTypes()) {
+    for (const char* func : {"cumulative_sum", "cumulative_sum_checked"}) {
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, 3, 4, 5, 6]"),
+                       ArrayFromJSON(ty, "[11, 13, 16, 20, 25, 31]"), &start_at_ten);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[11, 13, null, null, null, null]"),
+                       &start_at_ten);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[null, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[null, null, null, null, null, null]"),
+                       &start_at_ten);
+
+      // The running sum carries across chunk boundaries.
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, 3]", "[4, 5, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[11, 13, 16, 20, 25, 31]"}),
+                       &start_at_ten);
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[11, 13, null, null, null, null]"}),
+                       &start_at_ten);
+      CheckVectorUnary(func,
+                       ChunkedArrayFromJSON(ty, {"[null, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[null, null, null, null, null, null]"}),
+                       &start_at_ten);
+    }
+  }
+}
+
+TEST(TestCumulativeSum, HasStartDoSkip) {
+  // Start at 10 and skip nulls: each null input maps to a null output, and the
+  // running sum continues with the next valid value.
+  CumulativeSumOptions start_at_ten_skipping(10, true);
+  for (auto ty : NumericTypes()) {
+    for (const char* func : {"cumulative_sum", "cumulative_sum_checked"}) {
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, 3, 4, 5, 6]"),
+                       ArrayFromJSON(ty, "[11, 13, 16, 20, 25, 31]"),
+                       &start_at_ten_skipping);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[1, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[11, 13, null, 17, null, 23]"),
+                       &start_at_ten_skipping);
+      CheckVectorUnary(func, ArrayFromJSON(ty, "[null, 2, null, 4, null, 6]"),
+                       ArrayFromJSON(ty, "[null, 12, null, 16, null, 22]"),
+                       &start_at_ten_skipping);
+
+      // The running sum carries across chunk boundaries.
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, 3]", "[4, 5, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[11, 13, 16, 20, 25, 31]"}),
+                       &start_at_ten_skipping);
+      CheckVectorUnary(func, ChunkedArrayFromJSON(ty, {"[1, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[11, 13, null, 17, null, 23]"}),
+                       &start_at_ten_skipping);
+      CheckVectorUnary(func,
+                       ChunkedArrayFromJSON(ty, {"[null, 2, null]", "[4, null, 6]"}),
+                       ChunkedArrayFromJSON(ty, {"[null, 12, null, 16, null, 22]"}),
+                       &start_at_ten_skipping);
+    }
+  }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc
b/cpp/src/arrow/compute/kernels/vector_sort.cc
index 88c8a193fd..8a108621d3 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -1854,18 +1854,14 @@ class SelectKUnstableMetaFunction : public MetaFunction
{
return Status::Invalid("select_k_unstable requires a non-empty
`sort_keys`");
}
switch (args[0].kind()) {
- case Datum::ARRAY: {
+ case Datum::ARRAY:
return SelectKth(*args[0].make_array(), select_k_options, ctx);
- } break;
- case Datum::CHUNKED_ARRAY: {
+ case Datum::CHUNKED_ARRAY:
return SelectKth(*args[0].chunked_array(), select_k_options, ctx);
- } break;
case Datum::RECORD_BATCH:
return SelectKth(*args[0].record_batch(), select_k_options, ctx);
- break;
case Datum::TABLE:
return SelectKth(*args[0].table(), select_k_options, ctx);
- break;
default:
break;
}
diff --git a/cpp/src/arrow/compute/registry.cc
b/cpp/src/arrow/compute/registry.cc
index 8ab83a72e5..7e1975d3b6 100644
--- a/cpp/src/arrow/compute/registry.cc
+++ b/cpp/src/arrow/compute/registry.cc
@@ -174,6 +174,7 @@ static std::unique_ptr<FunctionRegistry>
CreateBuiltInRegistry() {
// Vector functions
RegisterVectorArraySort(registry.get());
+ RegisterVectorCumulativeSum(registry.get());
RegisterVectorHash(registry.get());
RegisterVectorNested(registry.get());
RegisterVectorReplace(registry.get());
diff --git a/cpp/src/arrow/compute/registry_internal.h
b/cpp/src/arrow/compute/registry_internal.h
index 35f7b07952..38f81e9888 100644
--- a/cpp/src/arrow/compute/registry_internal.h
+++ b/cpp/src/arrow/compute/registry_internal.h
@@ -43,6 +43,7 @@ void RegisterScalarOptions(FunctionRegistry* registry);
// Vector functions
void RegisterVectorArraySort(FunctionRegistry* registry);
+void RegisterVectorCumulativeSum(FunctionRegistry* registry);
void RegisterVectorHash(FunctionRegistry* registry);
void RegisterVectorNested(FunctionRegistry* registry);
void RegisterVectorReplace(FunctionRegistry* registry);
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 50977b750c..c373fa8988 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -1551,6 +1551,32 @@ random generator.
Array-wise ("vector") functions
-------------------------------
+Cumulative Functions
+~~~~~~~~~~~~~~~~~~~~
+
+Cumulative functions are vector functions that perform a running total on their
+input using a given binary associative operation and output an array containing
+the corresponding intermediate running values. The input is expected to be of
+numeric type. By default these functions do not detect overflow. They are also
+available in an overflow-checking variant, suffixed ``_checked``, which returns
+an ``Invalid`` :class:`Status` when overflow is detected.
+
++------------------------+-------+-------------+-------------+--------------------------------+-------+
+| Function name          | Arity | Input types | Output type | Options class                  | Notes |
++========================+=======+=============+=============+================================+=======+
+| cumulative_sum         | Unary | Numeric     | Numeric     | :struct:`CumulativeSumOptions` | \(1)  |
++------------------------+-------+-------------+-------------+--------------------------------+-------+
+| cumulative_sum_checked | Unary | Numeric     | Numeric     | :struct:`CumulativeSumOptions` | \(1)  |
++------------------------+-------+-------------+-------------+--------------------------------+-------+
+
+* \(1) CumulativeSumOptions has two optional parameters. The first parameter
+  :member:`CumulativeSumOptions::start` is a starting value for the running
+  sum. It has a default value of 0. Specified values of ``start`` must be
+  convertible to the input type. The second parameter
+  :member:`CumulativeSumOptions::skip_nulls` is a boolean. When set to
+  false (the default), the first encountered null is propagated to all
+  subsequent outputs. When set to true, each null in the input produces a
+  corresponding null in the output, and the running sum continues with the
+  next non-null value.
+
Associative transforms
~~~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/source/python/api/compute.rst
b/docs/source/python/api/compute.rst
index 3d52f48ed8..4a9208fd31 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -45,6 +45,21 @@ Aggregations
tdigest
variance
+Cumulative Functions
+--------------------
+
+Cumulative functions are vector functions that perform a running total on their
+input and output an array containing the corresponding intermediate running values.
+By default these functions do not detect overflow. They are also
+available in an overflow-checking variant, suffixed ``_checked``, which
+raises an ``ArrowInvalid`` exception when overflow is detected.
+
+.. autosummary::
+ :toctree: ../generated/
+
+ cumulative_sum
+ cumulative_sum_checked
+
Arithmetic Functions
--------------------
@@ -502,6 +517,7 @@ Compute Options
CastOptions
CountOptions
CountOptions
+ CumulativeSumOptions
DayOfWeekOptions
DictionaryEncodeOptions
ElementWiseAggregateOptions
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 96da505f76..e74404a771 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1748,6 +1748,34 @@ class PartitionNthOptions(_PartitionNthOptions):
self._set_options(pivot, null_placement)
+cdef class _CumulativeSumOptions(FunctionOptions):
+    def _set_options(self, start, skip_nulls):
+        # Plain Python values are converted to a pyarrow Scalar first. A value
+        # that cannot be converted is reported through
+        # _raise_invalid_function_option with a TypeError.
+        if not isinstance(start, Scalar):
+            try:
+                start = lib.scalar(start)
+            except Exception:
+                _raise_invalid_function_option(
+                    start, "`start` type for CumulativeSumOptions", TypeError)
+
+        # Hand the underlying C++ scalar to the C++ options object.
+        self.wrapped.reset(new CCumulativeSumOptions(
+            (<Scalar> start).unwrap(), skip_nulls))
+
+
+class CumulativeSumOptions(_CumulativeSumOptions):
+    """
+    Options for `cumulative_sum` function.
+
+    Parameters
+    ----------
+    start : Scalar, default 0.0
+        Starting value for sum computation. Values that are not already a
+        Scalar are converted with ``pyarrow.scalar``.
+    skip_nulls : bool, default False
+        When false, the first encountered null is propagated to all
+        subsequent outputs. When true, each null input produces a null
+        output and the running sum continues with the next non-null value.
+    """
+
+    def __init__(self, start=0.0, *, skip_nulls=False):
+        self._set_options(start, skip_nulls)
+
+
cdef class _ArraySortOptions(FunctionOptions):
def _set_options(self, order, null_placement):
self.wrapped.reset(new CArraySortOptions(
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index e10536dd10..b89030004e 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -33,6 +33,7 @@ from pyarrow._compute import ( # noqa
AssumeTimezoneOptions,
CastOptions,
CountOptions,
+ CumulativeSumOptions,
DayOfWeekOptions,
DictionaryEncodeOptions,
ElementWiseAggregateOptions,
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index 2e51864b86..fe93ec9a2f 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2245,6 +2245,12 @@ cdef extern from "arrow/compute/api.h" namespace
"arrow::compute" nogil:
int64_t pivot
CNullPlacement null_placement
+ cdef cppclass CCumulativeSumOptions \
+ "arrow::compute::CumulativeSumOptions"(CFunctionOptions):
+ CCumulativeSumOptions(shared_ptr[CScalar] start, c_bool skip_nulls)
+ shared_ptr[CScalar] start
+ c_bool skip_nulls
+
cdef cppclass CArraySortOptions \
"arrow::compute::ArraySortOptions"(CFunctionOptions):
CArraySortOptions(CSortOrder, CNullPlacement)
diff --git a/python/pyarrow/tests/test_compute.py
b/python/pyarrow/tests/test_compute.py
index 45282a2867..2afee0c2d9 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -147,6 +147,7 @@ def test_option_class_equality():
pc.NullOptions(),
pc.PadOptions(5),
pc.PartitionNthOptions(1, null_placement="at_start"),
+ pc.CumulativeSumOptions(start=0, skip_nulls=False),
pc.QuantileOptions(),
pc.RandomOptions(10),
pc.ReplaceSliceOptions(0, 1, "a"),
@@ -2510,6 +2511,58 @@ def test_min_max_element_wise():
assert result == pa.array([1, 2, None])
+@pytest.mark.parametrize('start', (1.25, 10.5, -10.5))
+@pytest.mark.parametrize('skip_nulls', (True, False))
+def test_cumulative_sum(start, skip_nulls):
+    # Exact tests (e.g., integral types). `start` is truncated to an int and
+    # passed both as a Python int and as int8 / int64 scalars.
+    start_int = int(start)
+    int_starts = [start_int, pa.scalar(start_int, type=pa.int8()),
+                  pa.scalar(start_int, type=pa.int64())]
+    # Expected values below assume start == 0; `start` is added before comparing.
+    if skip_nulls:
+        with_null = pa.array([0, None, 20, 50])
+        chunked_with_null = pa.chunked_array([[0, None, 20, 50]])
+    else:
+        with_null = pa.array([0, None, None, None])
+        chunked_with_null = pa.chunked_array([[0, None, None, None]])
+    int_cases = [
+        (pa.array([1, 2, 3]), pa.array([1, 3, 6])),
+        (pa.array([0, None, 20, 30]), with_null),
+        (pa.chunked_array([[0, None], [20, 30]]), chunked_with_null),
+    ]
+    for strt in int_starts:
+        for arr, base in int_cases:
+            result = pc.cumulative_sum(arr, start=strt, skip_nulls=skip_nulls)
+            assert result.equals(pc.add(base, strt))
+
+    # Inexact tests (floating-point types) are compared approximately.
+    float_starts = [start, pa.scalar(start, type=pa.float32()),
+                    pa.scalar(start, type=pa.float64())]
+    if skip_nulls:
+        nan_and_null = np.array([1, np.nan, None, np.nan, None, np.nan])
+    else:
+        nan_and_null = np.array([1, np.nan, None, None, None, None])
+    float_cases = [
+        (pa.array([1.125, 2.25, 3.03125]), np.array([1.125, 3.375, 6.40625])),
+        (pa.array([1, np.nan, 2, -3, 4, 5]),
+         np.array([1, np.nan, np.nan, np.nan, np.nan, np.nan])),
+        (pa.array([1, np.nan, None, 3, None, 5]), nan_and_null),
+    ]
+    for strt in float_starts:
+        for arr, base in float_cases:
+            result = pc.cumulative_sum(arr, start=strt, skip_nulls=skip_nulls)
+            expected = pc.add(base, strt)
+            np.testing.assert_array_almost_equal(
+                result.to_numpy(zero_copy_only=False),
+                expected.to_numpy(zero_copy_only=False))
+
+    # `start` values incompatible with an integer input are rejected.
+    for strt in ['a', pa.scalar('arrow'), 1.1]:
+        with pytest.raises(pa.ArrowInvalid):
+            pc.cumulative_sum([1, 2, 3], start=strt)
+
+
def test_make_struct():
assert pc.make_struct(1, 'a').as_py() == {'0': 1, '1': 'a'}