Repository: arrow Updated Branches: refs/heads/master a9a80fef7 -> b0b125fd7
ARROW-229: [C++] Implement cast functions for numeric types, booleans

Implements safe and unsafe casts amongst booleans, signed/unsigned integers, and single- and double-precision floating-point numbers. Currently the only safety check offered is for integer overflow when casting from a larger integer type to a smaller one (or from an unsigned integer type to the signed type of the same width).

This API should be regarded as experimental in 0.7.0. There are a number of follow-up patches we'll want to do quickly after this (exposing this in Python, incorporating it into Array.from_pandas).

Author: Wes McKinney <[email protected]>

Closes #1027 from wesm/ARROW-229 and squashes the following commits:

82fea97 [Wes McKinney] Fix MSVC warning
ead4a95 [Wes McKinney] Fix overflow check where overflow occurs in a null slot
dc7f8d9 [Wes McKinney] Some basic smoke tests to validate implemented casts
879653d [Wes McKinney] Start test suite for Cast
22308ba [Wes McKinney] Implement cast kernels for numbers. Add helper type traits
ca1c813 [Wes McKinney] Work on context
d05c274 [Wes McKinney] Start some prototyping of a cast implementation

Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/b0b125fd
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/b0b125fd
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/b0b125fd

Branch: refs/heads/master
Commit: b0b125fd74b2bb334e90d9775a670bf18ffd8a22
Parents: a9a80fe
Author: Wes McKinney <[email protected]>
Authored: Thu Sep 7 13:00:37 2017 -0400
Committer: Wes McKinney <[email protected]>
Committed: Thu Sep 7 13:00:37 2017 -0400

----------------------------------------------------------------------
cpp/CMakeLists.txt | 37 +++-
cpp/src/arrow/array.h | 2 +-
cpp/src/arrow/compute/CMakeLists.txt | 28 +++
cpp/src/arrow/compute/cast.cc | 329 +++++++++++++++++++++++++++++
cpp/src/arrow/compute/cast.h | 55 +++++
cpp/src/arrow/compute/compute-test.cc | 315 +++++++++++++++++++++++++++
cpp/src/arrow/compute/context.cc | 46 ++++
cpp/src/arrow/compute/context.h | 68 ++++++
cpp/src/arrow/memory_pool.cc | 4 +-
cpp/src/arrow/test-util.h | 10 +
cpp/src/arrow/type.h | 82 +------
cpp/src/arrow/type_traits.h | 68 ++++++
12 files changed, 956 insertions(+), 88 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 9f9d71b..24735ac 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -94,6 +94,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
     "Exclude deprecated APIs from build"
     OFF)
 
+  option(ARROW_COMPUTE
+    "Build the Arrow Compute Modules"
+    ON)
+
   option(ARROW_EXTRA_ERROR_CONTEXT
     "Compile with extra error context (line numbers, code)"
     OFF)
@@ -727,17 +731,6 @@ endif()
 add_subdirectory(src/arrow)
 add_subdirectory(src/arrow/io)
 
-if (ARROW_GPU)
-  # IPC extensions required to build the GPU library
-  set(ARROW_IPC ON)
-  add_subdirectory(src/arrow/gpu)
-endif()
-
-if (ARROW_IPC)
-  add_subdirectory(src/arrow/ipc)
-  add_dependencies(arrow_dependencies metadata_fbs)
-endif()
-
 set(ARROW_SRCS
   src/arrow/array.cc
   src/arrow/buffer.cc
@@ -763,6 +756,25 @@ set(ARROW_SRCS
   src/arrow/util/key_value_metadata.cc
 )
 
+# The compute sources are only compiled when ARROW_COMPUTE is ON; they must not
+# also be listed unconditionally in ARROW_SRCS above.
+if (ARROW_COMPUTE)
+  add_subdirectory(src/arrow/compute)
+  set(ARROW_SRCS ${ARROW_SRCS}
+    src/arrow/compute/cast.cc
+    src/arrow/compute/context.cc
+  )
+endif()
+
+if (ARROW_GPU)
+  # IPC extensions required to build the GPU library
+  set(ARROW_IPC ON)
+  add_subdirectory(src/arrow/gpu)
+endif()
+
+if (ARROW_IPC)
+  add_subdirectory(src/arrow/ipc)
+  add_dependencies(arrow_dependencies metadata_fbs)
+endif()
+
 if (ARROW_WITH_BROTLI)
   add_definitions(-DARROW_WITH_BROTLI)
   SET(ARROW_SRCS src/arrow/util/compression_brotli.cc ${ARROW_SRCS})

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/array.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 8e965e8..61ab2ef 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -678,4 +678,4 @@ MakePrimitiveArray(const std::shared_ptr<DataType>& type,
 
 }  // namespace arrow
 
-#endif
+#endif  // ARROW_ARRAY_H

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt
new file mode 100644
index 0000000..a154c47
--- /dev/null
+++ b/cpp/src/arrow/compute/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Headers: top level
+install(FILES
+  cast.h
+  context.h
+  DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/compute")
+
+#######################################
+# Unit tests
+#######################################
+
+ADD_ARROW_TEST(compute-test)

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/cast.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
new file mode 100644
index 0000000..f610f6b
--- /dev/null
+++ b/cpp/src/arrow/compute/cast.cc
@@ -0,0 +1,329 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/cast.h"
+
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <type_traits>
+
+#include "arrow/type_traits.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/compute/context.h"
+
+namespace arrow {
+namespace compute {
+
+// Per-invocation state handed to every cast kernel
+struct CastContext {
+  FunctionContext* func_ctx;
+  CastOptions options;
+};
+
+// A cast kernel reads `input` (honoring input.offset) and fills `output`, which
+// AllocateLike has already sized for the target type with offset 0 (the
+// zero-copy same-type kernel replaces its contents instead). Kernels report
+// errors through CastContext::func_ctx->SetStatus rather than a return value.
+typedef std::function<void(CastContext*, const ArrayData&, ArrayData*)> CastFunction;
+
+template <typename OutType, typename InType, typename Enable = void>
+struct CastFunctor {};
+
+// Type is the same, no computation required
+template <typename O, typename I>
+struct CastFunctor<O, I, typename std::enable_if<std::is_same<I, O>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    output->type = input.type;
+    output->buffers = input.buffers;
+    output->length = input.length;
+    output->offset = input.offset;
+    output->null_count = input.null_count;
+    output->child_data = input.child_data;
+  }
+};
+
+// ----------------------------------------------------------------------
+// Null to other things
+
+template <typename T>
+struct CastFunctor<T, NullType,
+                   typename std::enable_if<!std::is_same<T, NullType>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    ctx->func_ctx->SetStatus(Status::NotImplemented("NullType"));
+  }
+};
+
+// ----------------------------------------------------------------------
+// Boolean to other things
+
+// Cast from Boolean to other numbers
+template <typename T>
+struct CastFunctor<T, BooleanType,
+                   typename std::enable_if<std::is_base_of<Number, T>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using c_type = typename T::c_type;
+    const uint8_t* data = input.buffers[1]->data();
+    auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
+    constexpr auto kOne = static_cast<c_type>(1);
+    constexpr auto kZero = static_cast<c_type>(0);
+    // Boolean values are bit-packed, so input.offset is a bit offset into `data`
+    for (int64_t i = 0; i < input.length; ++i) {
+      *out++ = BitUtil::GetBit(data, input.offset + i) ? kOne : kZero;
+    }
+  }
+};
+
+// ----------------------------------------------------------------------
+// Integers and Floating Point
+
+template <typename O, typename I>
+struct is_numeric_cast {
+  static constexpr bool value =
+      (std::is_base_of<Number, O>::value && std::is_base_of<Number, I>::value) &&
+      (!std::is_same<O, I>::value);
+};
+
+template <typename O, typename I, typename Enable = void>
+struct is_integer_downcast {
+  static constexpr bool value = false;
+};
+
+template <typename O, typename I>
+struct is_integer_downcast<
+    O, I, typename std::enable_if<std::is_base_of<Integer, O>::value &&
+                                  std::is_base_of<Integer, I>::value>::type> {
+  using O_T = typename O::c_type;
+  using I_T = typename I::c_type;
+
+  static constexpr bool value =
+      ((!std::is_same<O, I>::value) &&
+
+       // same size, but unsigned to signed
+       ((sizeof(O_T) == sizeof(I_T) && std::is_signed<O_T>::value &&
+         std::is_unsigned<I_T>::value) ||
+
+        // Smaller output size
+        (sizeof(O_T) < sizeof(I_T))));
+};
+
+// Range check for is_integer_downcast casts, evaluated in the input type.
+// Unsigned input is never below the output's minimum, so only the upper bound
+// is checked. (Casting a negative output minimum to an unsigned input type
+// would wrap to a huge value and reject every valid input.)
+template <typename OutT, typename InT, typename Enable = void>
+struct IntegerDowncastBounds {
+  static bool OutOfBounds(InT value) {
+    return value > static_cast<InT>(std::numeric_limits<OutT>::max());
+  }
+};
+
+// Signed input: per is_integer_downcast the output is strictly narrower, so
+// both output bounds are representable in InT.
+template <typename OutT, typename InT>
+struct IntegerDowncastBounds<OutT, InT,
+                             typename std::enable_if<std::is_signed<InT>::value>::type> {
+  static bool OutOfBounds(InT value) {
+    return value > static_cast<InT>(std::numeric_limits<OutT>::max()) ||
+           value < static_cast<InT>(std::numeric_limits<OutT>::min());
+  }
+};
+
+// Cast from numbers to Boolean: any non-zero value becomes true
+template <typename O, typename I>
+struct CastFunctor<O, I, typename std::enable_if<std::is_same<BooleanType, O>::value &&
+                                                 std::is_base_of<Number, I>::value &&
+                                                 !std::is_same<O, I>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    auto in_data =
+        reinterpret_cast<const in_type*>(input.buffers[1]->data()) + input.offset;
+    uint8_t* out_data = reinterpret_cast<uint8_t*>(output->buffers[1]->mutable_data());
+    for (int64_t i = 0; i < input.length; ++i) {
+      BitUtil::SetBitTo(out_data, i, (*in_data++) != 0);
+    }
+  }
+};
+
+// Integer downcast, optionally checking that every non-null value fits
+template <typename O, typename I>
+struct CastFunctor<O, I,
+                   typename std::enable_if<is_integer_downcast<O, I>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    using out_type = typename O::c_type;
+    using Bounds = IntegerDowncastBounds<out_type, in_type>;
+
+    auto in_offset = input.offset;
+
+    auto in_data = reinterpret_cast<const in_type*>(input.buffers[1]->data()) + in_offset;
+    auto out_data = reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
+
+    if (!ctx->options.allow_int_overflow) {
+      if (input.null_count != 0 && input.buffers[0] != nullptr) {
+        // Values stored in null slots are arbitrary, so only check valid slots
+        const uint8_t* is_valid = input.buffers[0]->data();
+        int64_t is_valid_offset = in_offset;
+        for (int64_t i = 0; i < input.length; ++i) {
+          if (ARROW_PREDICT_FALSE(BitUtil::GetBit(is_valid, is_valid_offset++) &&
+                                  Bounds::OutOfBounds(*in_data))) {
+            ctx->func_ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
+            // Cast() discards the output on error, so stop converting
+            return;
+          }
+          *out_data++ = static_cast<out_type>(*in_data++);
+        }
+      } else {
+        for (int64_t i = 0; i < input.length; ++i) {
+          if (ARROW_PREDICT_FALSE(Bounds::OutOfBounds(*in_data))) {
+            ctx->func_ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
+            return;
+          }
+          *out_data++ = static_cast<out_type>(*in_data++);
+        }
+      }
+    } else {
+      for (int64_t i = 0; i < input.length; ++i) {
+        *out_data++ = static_cast<out_type>(*in_data++);
+      }
+    }
+  }
+};
+
+// All other number-to-number casts: plain static_cast, no checking.
+// NOTE(review): floating point -> integer casts of out-of-range values are
+// undefined behavior in C++ and are not checked here.
+template <typename O, typename I>
+struct CastFunctor<O, I,
+                   typename std::enable_if<is_numeric_cast<O, I>::value &&
+                                           !is_integer_downcast<O, I>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    using out_type = typename O::c_type;
+
+    auto in_data =
+        reinterpret_cast<const in_type*>(input.buffers[1]->data()) + input.offset;
+    auto out_data = reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
+    for (int64_t i = 0; i < input.length; ++i) {
+      *out_data++ = static_cast<out_type>(*in_data++);
+    }
+  }
+};
+
+// ----------------------------------------------------------------------
+
+#define CAST_CASE(InType, OutType)                                          \
+  case OutType::type_id:                                                    \
+    return [](CastContext* ctx, const ArrayData& input, ArrayData* out) { \
+      CastFunctor<OutType, InType> func;                                    \
+      func(ctx, input, out);                                                \
+    }
+
+#define NUMERIC_CASES(FN, IN_TYPE) \
+  FN(IN_TYPE, BooleanType);        \
+  FN(IN_TYPE, UInt8Type);          \
+  FN(IN_TYPE, Int8Type);           \
+  FN(IN_TYPE, UInt16Type);         \
+  FN(IN_TYPE, Int16Type);          \
+  FN(IN_TYPE, UInt32Type);         \
+  FN(IN_TYPE, Int32Type);          \
+  FN(IN_TYPE, UInt64Type);         \
+  FN(IN_TYPE, Int64Type);          \
+  FN(IN_TYPE, FloatType);          \
+  FN(IN_TYPE, DoubleType);
+
+#define GET_CAST_FUNCTION(CapType)                                                     \
+  static CastFunction Get##CapType##CastFunc(const std::shared_ptr<DataType>& type) { \
+    switch (type->id()) {                                                              \
+      NUMERIC_CASES(CAST_CASE, CapType);                                               \
+      default:                                                                         \
+        break;                                                                         \
+    }                                                                                  \
+    return nullptr;                                                                    \
+  }
+
+#define CAST_FUNCTION_CASE(CapType)          \
+  case CapType::type_id:                     \
+    *out = Get##CapType##CastFunc(out_type); \
+    break
+
+GET_CAST_FUNCTION(BooleanType);
+GET_CAST_FUNCTION(UInt8Type);
+GET_CAST_FUNCTION(Int8Type);
+GET_CAST_FUNCTION(UInt16Type);
+GET_CAST_FUNCTION(Int16Type);
+GET_CAST_FUNCTION(UInt32Type);
+GET_CAST_FUNCTION(Int32Type);
+GET_CAST_FUNCTION(UInt64Type);
+GET_CAST_FUNCTION(Int64Type);
+GET_CAST_FUNCTION(FloatType);
+GET_CAST_FUNCTION(DoubleType);
+
+// Look up the kernel for in_type -> out_type; NotImplemented if there is none
+static Status GetCastFunction(const DataType& in_type,
+                              const std::shared_ptr<DataType>& out_type,
+                              CastFunction* out) {
+  *out = nullptr;
+  switch (in_type.id()) {
+    CAST_FUNCTION_CASE(BooleanType);
+    CAST_FUNCTION_CASE(UInt8Type);
+    CAST_FUNCTION_CASE(Int8Type);
+    CAST_FUNCTION_CASE(UInt16Type);
+    CAST_FUNCTION_CASE(Int16Type);
+    CAST_FUNCTION_CASE(UInt32Type);
+    CAST_FUNCTION_CASE(Int32Type);
+    CAST_FUNCTION_CASE(UInt64Type);
+    CAST_FUNCTION_CASE(Int64Type);
+    CAST_FUNCTION_CASE(FloatType);
+    CAST_FUNCTION_CASE(DoubleType);
+    default:
+      break;
+  }
+  if (*out == nullptr) {
+    std::stringstream ss;
+    ss << "No cast implemented from " << in_type.ToString() << " to "
+       << out_type->ToString();
+    return Status::NotImplemented(ss.str());
+  }
+  return Status::OK();
+}
+
+// Allocate the output ArrayData for casting `array` to `out_type`: the null
+// bitmap is propagated from the input (realigned to offset 0 when the input is
+// sliced) and a data buffer sized for the output type is allocated.
+static Status AllocateLike(FunctionContext* ctx, const Array& array,
+                           const std::shared_ptr<DataType>& out_type,
+                           std::shared_ptr<ArrayData>* out) {
+  if (!is_primitive(out_type->id())) {
+    return Status::NotImplemented(out_type->ToString());
+  }
+
+  const auto& fw_type = static_cast<const FixedWidthType&>(*out_type);
+
+  auto result = std::make_shared<ArrayData>();
+  result->type = out_type;
+  result->length = array.length();
+  result->offset = 0;
+  result->null_count = array.null_count();
+
+  // Propagate null bitmap. The output always starts at offset 0, so the input
+  // bitmap can only be shared as-is when the input is not sliced.
+  // TODO(wesm): handling null bitmap when input type is NullType
+  std::shared_ptr<Buffer> null_bitmap;
+  if (array.offset() == 0) {
+    null_bitmap = array.data()->buffers[0];
+  } else if (result->null_count != 0) {
+    RETURN_NOT_OK(ctx->Allocate(BitUtil::BytesForBits(array.length()), &null_bitmap));
+    const uint8_t* in_bits = array.data()->buffers[0]->data();
+    uint8_t* out_bits = null_bitmap->mutable_data();
+    for (int64_t i = 0; i < array.length(); ++i) {
+      BitUtil::SetBitTo(out_bits, i, BitUtil::GetBit(in_bits, array.offset() + i));
+    }
+  }
+  result->buffers.push_back(null_bitmap);
+
+  std::shared_ptr<Buffer> out_data;
+
+  const int bit_width = fw_type.bit_width();
+
+  if (bit_width == 1) {
+    RETURN_NOT_OK(ctx->Allocate(BitUtil::BytesForBits(array.length()), &out_data));
+  } else if (bit_width % 8 == 0) {
+    RETURN_NOT_OK(ctx->Allocate(array.length() * bit_width / 8, &out_data));
+  } else {
+    // Sub-byte widths other than 1 bit have no allocation rule here
+    return Status::NotImplemented(out_type->ToString());
+  }
+  result->buffers.push_back(out_data);
+
+  *out = result;
+  return Status::OK();
+}
+
+static Status Cast(CastContext* cast_ctx, const Array& array,
+                   const std::shared_ptr<DataType>& out_type,
+                   std::shared_ptr<Array>* out) {
+  // Dynamic dispatch to obtain right cast function
+  CastFunction func;
+  RETURN_NOT_OK(GetCastFunction(*array.type(), out_type, &func));
+
+  // Allocate memory for output
+  std::shared_ptr<ArrayData> out_data;
+  RETURN_NOT_OK(AllocateLike(cast_ctx->func_ctx, array, out_type, &out_data));
+
+  func(cast_ctx, *array.data(), out_data.get());
+  // Kernels report failures through the FunctionContext; surface (and clear) it
+  RETURN_IF_ERROR(cast_ctx->func_ctx);
+  return internal::MakeArray(out_data, out);
+}
+
+Status Cast(FunctionContext* ctx, const Array& array,
+            const std::shared_ptr<DataType>& out_type, const CastOptions& options,
+            std::shared_ptr<Array>* out) {
+  CastContext cast_ctx{ctx, options};
+  return Cast(&cast_ctx, array, out_type, out);
+}
+
+}  // namespace compute
+}  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/cast.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h
new file mode 100644
index 0000000..9ca70aa
--- /dev/null
+++ b/cpp/src/arrow/compute/cast.h
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef ARROW_COMPUTE_CAST_H
+#define ARROW_COMPUTE_CAST_H
+
+#include <memory>
+
+#include "arrow/array.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+using internal::ArrayData;
+
+namespace compute {
+
+class FunctionContext;
+
+/// \brief Options controlling the behavior of Cast
+struct CastOptions {
+  CastOptions() : allow_int_overflow(false) {}
+  explicit CastOptions(bool allow_overflow) : allow_int_overflow(allow_overflow) {}
+
+  /// \brief Whether integer casts are allowed to overflow (default false)
+  ///
+  /// When false, casting an integer to a narrower integer type, or an unsigned
+  /// integer to the signed type of the same width, fails with Status::Invalid
+  /// if any non-null value does not fit in the output type. When true, values
+  /// are converted with static_cast and out-of-range values may wrap around.
+  bool allow_int_overflow;
+};
+
+/// \brief Cast from one array type to another
+/// \param[in] context function context; output buffers are allocated from its
+///   memory pool
+/// \param[in] array array to cast
+/// \param[in] to_type type to cast to
+/// \param[in] options casting options, see CastOptions
+/// \param[out] out the resulting array
+/// \return Status::NotImplemented if no cast is implemented between the two
+///   types; Status::Invalid if an integer value is out of bounds and
+///   options.allow_int_overflow is false
+///
+/// \since 0.7.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Status Cast(FunctionContext* context, const Array& array,
+            const std::shared_ptr<DataType>& to_type, const CastOptions& options,
+            std::shared_ptr<Array>* out);
+
+}  // namespace compute
+}  // namespace arrow
+
+#endif  // ARROW_COMPUTE_CAST_H

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/compute-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc
new file mode 100644
index 0000000..cda5755
--- /dev/null
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/builder.h"
+#include "arrow/compare.h"
+#include "arrow/ipc/test-common.h"
+#include "arrow/memory_pool.h"
+#include "arrow/pretty_print.h"
+#include "arrow/status.h"
+#include "arrow/test-common.h"
+#include "arrow/test-util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+
+#include "arrow/compute/cast.h"
+#include "arrow/compute/context.h"
+
+using std::vector;
+
+namespace arrow {
+namespace compute {
+
+// Fail the current test, printing both arrays, unless they compare equal
+void AssertArraysEqual(const Array& left, const Array& right) {
+  bool are_equal = false;
+  ASSERT_OK(ArrayEquals(left, right, &are_equal));
+
+  if (!are_equal) {
+    std::stringstream ss;
+
+    ss << "Left: ";
+    EXPECT_OK(PrettyPrint(left, 0, &ss));
+    ss << "\nRight: ";
+    EXPECT_OK(PrettyPrint(right, 0, &ss));
+    FAIL() << ss.str();
+  }
+}
+
+class ComputeFixture {
+ public:
+  ComputeFixture() : pool_(default_memory_pool()), ctx_(pool_) {}
+
+ protected:
+  MemoryPool* pool_;
+  FunctionContext ctx_;
+};
+
+// ----------------------------------------------------------------------
+// Cast
+
+class TestCast : public ComputeFixture, public ::testing::Test {
+ public:
+  void CheckPass(const Array& input, const Array& expected,
+                 const std::shared_ptr<DataType>& out_type, const CastOptions& options) {
+    std::shared_ptr<Array> result;
+    ASSERT_OK(Cast(&ctx_, input, out_type, options, &result));
+    AssertArraysEqual(expected, *result);
+  }
+
+  template <typename InType, typename I_TYPE>
+  void CheckFails(const std::shared_ptr<DataType>& in_type,
+                  const std::vector<I_TYPE>& in_values, const std::vector<bool>& is_valid,
+                  const std::shared_ptr<DataType>& out_type, const CastOptions& options) {
+    std::shared_ptr<Array> input, result;
+    if (is_valid.size() > 0) {
+      ArrayFromVector<InType, I_TYPE>(in_type, is_valid, in_values, &input);
+    } else {
+      ArrayFromVector<InType, I_TYPE>(in_type, in_values, &input);
+    }
+    ASSERT_RAISES(Invalid, Cast(&ctx_, *input, out_type, options, &result));
+  }
+
+  template <typename InType, typename I_TYPE, typename OutType, typename O_TYPE>
+  void CheckCase(const std::shared_ptr<DataType>& in_type,
+                 const std::vector<I_TYPE>& in_values, const std::vector<bool>& is_valid,
+                 const std::shared_ptr<DataType>& out_type,
+                 const std::vector<O_TYPE>& out_values, const CastOptions& options) {
+    std::shared_ptr<Array> input, expected;
+    if (is_valid.size() > 0) {
+      ArrayFromVector<InType, I_TYPE>(in_type, is_valid, in_values, &input);
+      ArrayFromVector<OutType, O_TYPE>(out_type, is_valid, out_values, &expected);
+    } else {
+      ArrayFromVector<InType, I_TYPE>(in_type, in_values, &input);
+      ArrayFromVector<OutType, O_TYPE>(out_type, out_values, &expected);
+    }
+    CheckPass(*input, *expected, out_type, options);
+  }
+};
+
+TEST_F(TestCast, SameTypeZeroCopy) {
+  vector<bool> is_valid = {true, false, true, true, true};
+  vector<int32_t> v1 = {0, 1, 2, 3, 4};
+
+  std::shared_ptr<Array> arr;
+  ArrayFromVector<Int32Type, int32_t>(int32(), is_valid, v1, &arr);
+
+  std::shared_ptr<Array> result;
+  ASSERT_OK(Cast(&this->ctx_, *arr, int32(), {}, &result));
+
+  const auto& lbuffers = arr->data()->buffers;
+  const auto& rbuffers = result->data()->buffers;
+
+  // Buffers are the same
+  ASSERT_EQ(lbuffers[0].get(), rbuffers[0].get());
+  ASSERT_EQ(lbuffers[1].get(), rbuffers[1].get());
+}
+
+TEST_F(TestCast, ToBoolean) {
+  // Value-initialize so no option is left indeterminate
+  CastOptions options{};
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int8, should suffice for other integers
+  vector<int8_t> v1 = {0, 1, 127, -1, 0};
+  vector<bool> e1 = {false, true, true, true, false};
+  CheckCase<Int8Type, int8_t, BooleanType, bool>(int8(), v1, is_valid, boolean(), e1,
+                                                 options);
+
+  // floating point
+  vector<double> v2 = {1.0, 0, 0, -1.0, 5.0};
+  vector<bool> e2 = {true, false, false, true, true};
+  CheckCase<DoubleType, double, BooleanType, bool>(float64(), v2, is_valid, boolean(), e2,
+                                                   options);
+}
+
+TEST_F(TestCast, ToIntUpcast) {
+  CastOptions options;
+  options.allow_int_overflow = false;
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int8 to int32
+  vector<int8_t> v1 = {0, 1, 127, -1, 0};
+  vector<int32_t> e1 = {0, 1, 127, -1, 0};
+  CheckCase<Int8Type, int8_t, Int32Type, int32_t>(int8(), v1, is_valid, int32(), e1,
+                                                  options);
+
+  // bool to int8
+  vector<bool> v2 = {false, true, false, true, true};
+  vector<int8_t> e2 = {0, 1, 0, 1, 1};
+  CheckCase<BooleanType, bool, Int8Type, int8_t>(boolean(), v2, is_valid, int8(), e2,
+                                                 options);
+
+  // uint8 to int16, no overflow/underrun
+  vector<uint8_t> v3 = {0, 100, 200, 255, 0};
+  vector<int16_t> e3 = {0, 100, 200, 255, 0};
+  CheckCase<UInt8Type, uint8_t, Int16Type, int16_t>(uint8(), v3, is_valid, int16(), e3,
+                                                    options);
+
+  // floating point to integer
+  vector<double> v4 = {1.5, 0, 0.5, -1.5, 5.5};
+  vector<int32_t> e4 = {1, 0, 0, -1, 5};
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, is_valid, int32(), e4,
+                                                    options);
+}
+
+TEST_F(TestCast, OverflowInNullSlot) {
+  CastOptions options;
+  options.allow_int_overflow = false;
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  vector<int32_t> v11 = {0, 70000, 2000, 1000, 0};
+  vector<int16_t> e11 = {0, 0, 2000, 1000, 0};
+
+  std::shared_ptr<Array> expected;
+  ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, e11, &expected);
+
+  // Buffer sizes are in bytes, not elements
+  auto buf = std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(v11.data()),
+                                      static_cast<int64_t>(v11.size() * sizeof(int32_t)));
+  Int32Array tmp11(5, buf, expected->null_bitmap(), -1);
+
+  CheckPass(tmp11, *expected, int16(), options);
+}
+
+TEST_F(TestCast, ToIntDowncastSafe) {
+  CastOptions options;
+  options.allow_int_overflow = false;
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int16 to uint8, no overflow/underrun
+  vector<int16_t> v5 = {0, 100, 200, 1, 2};
+  vector<uint8_t> e5 = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v5, is_valid, uint8(), e5,
+                                                    options);
+
+  // int16 to uint8, with overflow
+  vector<int16_t> v6 = {0, 100, 256, 0, 0};
+  CheckFails<Int16Type>(int16(), v6, is_valid, uint8(), options);
+
+  // underflow
+  vector<int16_t> v7 = {0, 100, -1, 0, 0};
+  CheckFails<Int16Type>(int16(), v7, is_valid, uint8(), options);
+
+  // int32 to int16, no overflow
+  vector<int32_t> v8 = {0, 1000, 2000, 1, 2};
+  vector<int16_t> e8 = {0, 1000, 2000, 1, 2};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v8, is_valid, int16(), e8,
+                                                    options);
+
+  // int32 to int16, overflow
+  vector<int32_t> v9 = {0, 1000, 2000, 70000, 0};
+  CheckFails<Int32Type>(int32(), v9, is_valid, int16(), options);
+
+  // underflow
+  vector<int32_t> v10 = {0, 1000, 2000, -70000, 0};
+  CheckFails<Int32Type>(int32(), v10, is_valid, int16(), options);
+}
+
+TEST_F(TestCast, ToIntDowncastUnsafe) {
+  CastOptions options;
+  options.allow_int_overflow = true;
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int16 to uint8, no overflow/underrun
+  vector<int16_t> v5 = {0, 100, 200, 1, 2};
+  vector<uint8_t> e5 = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v5, is_valid, uint8(), e5,
+                                                    options);
+
+  // int16 to uint8, with overflow
+  vector<int16_t> v6 = {0, 100, 256, 0, 0};
+  vector<uint8_t> e6 = {0, 100, 0, 0, 0};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v6, is_valid, uint8(), e6,
+                                                    options);
+
+  // underflow
+  vector<int16_t> v7 = {0, 100, -1, 0, 0};
+  vector<uint8_t> e7 = {0, 100, 255, 0, 0};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v7, is_valid, uint8(), e7,
+                                                    options);
+
+  // int32 to int16, no overflow
+  vector<int32_t> v8 = {0, 1000, 2000, 1, 2};
+  vector<int16_t> e8 = {0, 1000, 2000, 1, 2};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v8, is_valid, int16(), e8,
+                                                    options);
+
+  // int32 to int16, overflow
+  // TODO(wesm): do we want to allow this? we could set to null
+  vector<int32_t> v9 = {0, 1000, 2000, 70000, 0};
+  vector<int16_t> e9 = {0, 1000, 2000, 4464, 0};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v9, is_valid, int16(), e9,
+                                                    options);
+
+  // underflow
+  // TODO(wesm): do we want to allow this? we could set overflow to null
+  vector<int32_t> v10 = {0, 1000, 2000, -70000, 0};
+  vector<int16_t> e10 = {0, 1000, 2000, -4464, 0};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v10, is_valid, int16(), e10,
+                                                    options);
+}
+
+TEST_F(TestCast, ToDouble) {
+  // Value-initialize so no option is left indeterminate
+  CastOptions options{};
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int16 to double
+  vector<int16_t> v1 = {0, 100, 200, 1, 2};
+  vector<double> e1 = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, DoubleType, double>(int16(), v1, is_valid, float64(), e1,
+                                                    options);
+
+  // float to double
+  vector<float> v2 = {0, 100, 200, 1, 2};
+  vector<double> e2 = {0, 100, 200, 1, 2};
+  CheckCase<FloatType, float, DoubleType, double>(float32(), v2, is_valid, float64(), e2,
+                                                  options);
+
+  // bool to double
+  vector<bool> v3 = {true, true, false, false, true};
+  vector<double> e3 = {1, 1, 0, 0, 1};
+  CheckCase<BooleanType, bool, DoubleType, double>(boolean(), v3, is_valid, float64(), e3,
+                                                   options);
+}
+
+TEST_F(TestCast, UnsupportedTarget) {
+  vector<bool> is_valid = {true, false, true, true, true};
+  vector<int32_t> v1 = {0, 1, 2, 3, 4};
+
+  std::shared_ptr<Array> arr;
+  ArrayFromVector<Int32Type, int32_t>(int32(), is_valid, v1, &arr);
+
+  std::shared_ptr<Array> result;
+  ASSERT_RAISES(NotImplemented, Cast(&this->ctx_, *arr, utf8(), {}, &result));
+}
+
+}  // namespace compute
+}  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/context.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/context.cc b/cpp/src/arrow/compute/context.cc
new file mode 100644
index 0000000..792dc4f
--- /dev/null
+++
b/cpp/src/arrow/compute/context.cc
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/context.h"
+
+#include <memory>
+
+#include "arrow/buffer.h"
+
+namespace arrow {
+namespace compute {
+
+FunctionContext::FunctionContext(MemoryPool* pool) : pool_(pool) {}
+
+MemoryPool* FunctionContext::memory_pool() const { return pool_; }
+
+Status FunctionContext::Allocate(const int64_t nbytes, std::shared_ptr<Buffer>* out) {
+  return AllocateBuffer(pool_, nbytes, out);
+}
+
+void FunctionContext::SetStatus(const Status& status) {
+  if (ARROW_PREDICT_FALSE(!status_.ok())) {
+    return;
+  }
+  status_ = status;
+}
+
+/// \brief Clear any error status
+void FunctionContext::ResetStatus() { status_ = Status::OK(); }
+
+}  // namespace compute
+}  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/context.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/context.h b/cpp/src/arrow/compute/context.h
new file mode 100644
index 0000000..caff2e2
--- /dev/null
+++ b/cpp/src/arrow/compute/context.h
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef ARROW_COMPUTE_CONTEXT_H
+#define ARROW_COMPUTE_CONTEXT_H
+
+#include <memory>
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+/// \brief If `ctx` has recorded an error, clear it and return that error from
+/// the enclosing function
+///
+/// Wrapped in do/while so it behaves as a single statement (safe inside an
+/// unbraced if/else). Note that `ctx` is evaluated more than once.
+#define RETURN_IF_ERROR(ctx)                      \
+  do {                                            \
+    if (ARROW_PREDICT_FALSE((ctx)->HasError())) { \
+      ::arrow::Status _s = (ctx)->status();       \
+      (ctx)->ResetStatus();                       \
+      return _s;                                  \
+    }                                             \
+  } while (false)
+
+/// \brief Container for variables and options used by function evaluation
+class ARROW_EXPORT FunctionContext {
+ public:
+  explicit FunctionContext(MemoryPool* pool);
+  MemoryPool* memory_pool() const;
+
+  /// \brief Allocate buffer from the context's memory pool
+  Status Allocate(const int64_t nbytes, std::shared_ptr<Buffer>* out);
+
+  /// \brief Indicate that an error has occurred, to be checked by a parent caller
+  /// \param[in] status a Status instance
+  ///
+  /// \note Will not overwrite a prior set Status, so we will have the first
+  /// error that occurred until FunctionContext::ResetStatus is called
+  void SetStatus(const Status& status);
+
+  /// \brief Clear any error status
+  void ResetStatus();
+
+  /// \brief Return true if an error has occurred
+  bool HasError() const { return !status_.ok(); }
+
+  /// \brief Return the current status of the context
+  const Status& status() const { return status_; }
+
+ private:
+  Status status_;
+  MemoryPool* pool_;
+};
+
+}  // namespace compute
+}  // namespace arrow
+
+#endif  // ARROW_COMPUTE_CONTEXT_H

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/memory_pool.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc
index 7fd999e..d86fb08 100644
--- a/cpp/src/arrow/memory_pool.cc
+++ b/cpp/src/arrow/memory_pool.cc
@@ -168,8 +168,8 @@ Status LoggingMemoryPool::Allocate(int64_t size, uint8_t** out) {
 
 Status LoggingMemoryPool::Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
   Status s = pool_->Reallocate(old_size, new_size, ptr);
-  std::cout << "Reallocate: old_size = " << old_size
-            << " - new_size = " << new_size << std::endl;
+  std::cout << "Reallocate: old_size = " << old_size << " - new_size = " << new_size
+            << std::endl;
   return s;
 }
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/test-util.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/test-util.h b/cpp/src/arrow/test-util.h
index 91f2bc9..22a933d 100644
--- a/cpp/src/arrow/test-util.h
+++ b/cpp/src/arrow/test-util.h
@@ -233,6 +233,16 @@ void ArrayFromVector(const std::shared_ptr<DataType>& type,
 }
 
 template <typename TYPE, typename C_TYPE>
+void ArrayFromVector(const std::shared_ptr<DataType>& type,
+                     const std::vector<C_TYPE>& values, std::shared_ptr<Array>* out) {
+  typename TypeTraits<TYPE>::BuilderType builder(type, default_memory_pool());
+  for (size_t i = 0; i < values.size(); ++i) {
+    ASSERT_OK(builder.Append(values[i]));
+  }
+  ASSERT_OK(builder.Finish(out));
+}
+
+template <typename TYPE, typename C_TYPE>
 void ArrayFromVector(const std::vector<bool>& is_valid, const std::vector<C_TYPE>& values,
                      std::shared_ptr<Array>* out) {
   typename TypeTraits<TYPE>::BuilderType builder;
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/type.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index b532cd2..283e27e 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -186,15 +186,20 @@ class ARROW_EXPORT PrimitiveCType : public FixedWidthType { using FixedWidthType::FixedWidthType; }; -class ARROW_EXPORT Integer : public PrimitiveCType { +class ARROW_EXPORT Number : public PrimitiveCType { public: using PrimitiveCType::PrimitiveCType; +}; + +class ARROW_EXPORT Integer : public Number { + public: + using Number::Number; virtual bool is_signed() const = 0; }; -class ARROW_EXPORT FloatingPoint : public PrimitiveCType { +class ARROW_EXPORT FloatingPoint : public Number { public: - using PrimitiveCType::PrimitiveCType; + using Number::Number; enum Precision { HALF, SINGLE, DOUBLE }; virtual Precision precision() const = 0; }; @@ -842,77 +847,6 @@ std::shared_ptr<Schema> schema( std::vector<std::shared_ptr<Field>>&& fields, const std::shared_ptr<const KeyValueMetadata>& metadata = nullptr); -// ---------------------------------------------------------------------- -// - -static inline bool is_integer(Type::type type_id) { - switch (type_id) { - case Type::UINT8: - case Type::INT8: - case Type::UINT16: - case Type::INT16: - case Type::UINT32: - case Type::INT32: - case Type::UINT64: - case Type::INT64: - return true; - default: - break; - } - return false; -} - -static inline bool is_floating(Type::type type_id) { - switch (type_id) { - case Type::HALF_FLOAT: - case Type::FLOAT: - case Type::DOUBLE: - return true; - default: - break; - } - return false; -} - -static inline bool is_primitive(Type::type type_id) { - switch (type_id) { - case Type::NA: - case Type::BOOL: - case Type::UINT8: - case Type::INT8: - case Type::UINT16: - case Type::INT16: - case Type::UINT32: - case Type::INT32: - case Type::UINT64: - case Type::INT64: - case 
Type::HALF_FLOAT: - case Type::FLOAT: - case Type::DOUBLE: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::INTERVAL: - return true; - default: - break; - } - return false; -} - -static inline bool is_binary_like(Type::type type_id) { - switch (type_id) { - case Type::BINARY: - case Type::STRING: - return true; - default: - break; - } - return false; -} - } // namespace arrow #endif // ARROW_TYPE_H http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/type_traits.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index d424cc8..fbd7839 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -362,6 +362,74 @@ struct IsNumeric { static constexpr bool value = std::is_arithmetic<c_type>::value; }; +static inline bool is_integer(Type::type type_id) { + switch (type_id) { + case Type::UINT8: + case Type::INT8: + case Type::UINT16: + case Type::INT16: + case Type::UINT32: + case Type::INT32: + case Type::UINT64: + case Type::INT64: + return true; + default: + break; + } + return false; +} + +static inline bool is_floating(Type::type type_id) { + switch (type_id) { + case Type::HALF_FLOAT: + case Type::FLOAT: + case Type::DOUBLE: + return true; + default: + break; + } + return false; +} + +static inline bool is_primitive(Type::type type_id) { + switch (type_id) { + case Type::NA: + case Type::BOOL: + case Type::UINT8: + case Type::INT8: + case Type::UINT16: + case Type::INT16: + case Type::UINT32: + case Type::INT32: + case Type::UINT64: + case Type::INT64: + case Type::HALF_FLOAT: + case Type::FLOAT: + case Type::DOUBLE: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::INTERVAL: + return true; + default: + break; + } + return false; +} + +static inline bool is_binary_like(Type::type type_id) { + 
switch (type_id) { + case Type::BINARY: + case Type::STRING: + return true; + default: + break; + } + return false; +} + } // namespace arrow #endif // ARROW_TYPE_TRAITS_H
