[ https://issues.apache.org/jira/browse/ARROW-1484?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16219848#comment-16219848 ]
ASF GitHub Bot commented on ARROW-1484:
---------------------------------------
wesm closed pull request #1245: ARROW-1484: [C++/Python] Implement casts
between date, time, timestamp units
URL: https://github.com/apache/arrow/pull/1245
This is a PR merged from a forked repository.
GitHub hides the original diff on merge, and it does not otherwise show
the diff of a foreign pull request (one from a fork). The diff is therefore
displayed below for the sake of provenance:
diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
index e8bbfd347..68a2b1237 100644
--- a/cpp/src/arrow/compute/cast.cc
+++ b/cpp/src/arrow/compute/cast.cc
@@ -25,6 +25,7 @@
#include <sstream>
#include <string>
#include <type_traits>
+#include <utility>
#include "arrow/array.h"
#include "arrow/buffer.h"
@@ -68,6 +69,24 @@
namespace arrow {
namespace compute {
+template <typename T>
+inline const T* GetValuesAs(const ArrayData& data, int i) {
+ return reinterpret_cast<const T*>(data.buffers[i]->data()) + data.offset;
+}
+
+namespace {
+
+void CopyData(const Array& input, ArrayData* output) {
+ auto in_data = input.data();
+ output->length = in_data->length;
+ output->null_count = input.null_count();
+ output->buffers = in_data->buffers;
+ output->offset = in_data->offset;
+ output->child_data = in_data->child_data;
+}
+
+} // namespace
+
// ----------------------------------------------------------------------
// Zero copy casts
@@ -77,7 +96,9 @@ struct is_zero_copy_cast {
};
template <typename O, typename I>
-struct is_zero_copy_cast<O, I, typename std::enable_if<std::is_same<I,
O>::value>::type> {
+struct is_zero_copy_cast<
+ O, I, typename std::enable_if<std::is_same<I, O>::value &&
+ !std::is_base_of<ParametricType,
O>::value>::type> {
static constexpr bool value = true;
};
@@ -102,10 +123,7 @@ template <typename O, typename I>
struct CastFunctor<O, I, typename std::enable_if<is_zero_copy_cast<O,
I>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options, const
Array& input,
ArrayData* output) {
- auto in_data = input.data();
- output->null_count = input.null_count();
- output->buffers = in_data->buffers;
- output->child_data = in_data->child_data;
+ CopyData(input, output);
}
};
@@ -119,6 +137,7 @@ struct CastFunctor<T, NullType, typename std::enable_if<
ArrayData* output) {
// Simply initialize data to 0
auto buf = output->buffers[1];
+ DCHECK_EQ(output->offset, 0);
memset(buf->mutable_data(), 0, buf->size());
}
};
@@ -139,12 +158,16 @@ struct CastFunctor<T, BooleanType,
void operator()(FunctionContext* ctx, const CastOptions& options, const
Array& input,
ArrayData* output) {
using c_type = typename T::c_type;
- const uint8_t* data = input.data()->buffers[1]->data();
- auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
constexpr auto kOne = static_cast<c_type>(1);
constexpr auto kZero = static_cast<c_type>(0);
+
+ auto in_data = input.data();
+ internal::BitmapReader bit_reader(in_data->buffers[1]->data(),
in_data->offset,
+ in_data->length);
+ auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
for (int64_t i = 0; i < input.length(); ++i) {
- *out++ = BitUtil::GetBit(data, i) ? kOne : kZero;
+ *out++ = bit_reader.IsSet() ? kOne : kZero;
+ bit_reader.Next();
}
}
};
@@ -189,7 +212,9 @@ struct CastFunctor<O, I, typename
std::enable_if<std::is_same<BooleanType, O>::v
void operator()(FunctionContext* ctx, const CastOptions& options, const
Array& input,
ArrayData* output) {
using in_type = typename I::c_type;
- auto in_data = reinterpret_cast<const
in_type*>(input.data()->buffers[1]->data());
+ DCHECK_EQ(output->offset, 0);
+
+ const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
uint8_t* out_data =
reinterpret_cast<uint8_t*>(output->buffers[1]->mutable_data());
for (int64_t i = 0; i < input.length(); ++i) {
BitUtil::SetBitTo(out_data, i, (*in_data++) != 0);
@@ -204,12 +229,11 @@ struct CastFunctor<O, I,
ArrayData* output) {
using in_type = typename I::c_type;
using out_type = typename O::c_type;
+ DCHECK_EQ(output->offset, 0);
auto in_offset = input.offset();
- const auto& input_buffers = input.data()->buffers;
-
- auto in_data = reinterpret_cast<const in_type*>(input_buffers[1]->data())
+ in_offset;
+ const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
auto out_data =
reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
if (!options.allow_int_overflow) {
@@ -217,14 +241,15 @@ struct CastFunctor<O, I,
constexpr in_type kMin =
static_cast<in_type>(std::numeric_limits<out_type>::min());
if (input.null_count() > 0) {
- const uint8_t* is_valid = input_buffers[0]->data();
- int64_t is_valid_offset = in_offset;
+ internal::BitmapReader
is_valid_reader(input.data()->buffers[0]->data(),
+ in_offset, input.length());
for (int64_t i = 0; i < input.length(); ++i) {
- if (ARROW_PREDICT_FALSE(BitUtil::GetBit(is_valid, is_valid_offset++)
&&
+ if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet() &&
(*in_data > kMax || *in_data < kMin))) {
ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
}
*out_data++ = static_cast<out_type>(*in_data++);
+ is_valid_reader.Next();
}
} else {
for (int64_t i = 0; i < input.length(); ++i) {
@@ -251,7 +276,7 @@ struct CastFunctor<O, I,
using in_type = typename I::c_type;
using out_type = typename O::c_type;
- auto in_data = reinterpret_cast<const
in_type*>(input.data()->buffers[1]->data());
+ const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
auto out_data =
reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
for (int64_t i = 0; i < input.length(); ++i) {
*out_data++ = static_cast<out_type>(*in_data++);
@@ -260,6 +285,125 @@ struct CastFunctor<O, I,
};
// ----------------------------------------------------------------------
+// From one timestamp to another
+
+template <typename in_type, typename out_type>
+inline void ShiftTime(FunctionContext* ctx, const CastOptions& options,
+ const bool is_multiply, const int64_t factor, const
Array& input,
+ ArrayData* output) {
+ const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
+ auto out_data =
reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
+
+ if (is_multiply) {
+ for (int64_t i = 0; i < input.length(); i++) {
+ out_data[i] = static_cast<out_type>(in_data[i] * factor);
+ }
+ } else {
+ if (options.allow_time_truncate) {
+ for (int64_t i = 0; i < input.length(); i++) {
+ out_data[i] = static_cast<out_type>(in_data[i] / factor);
+ }
+ } else {
+ for (int64_t i = 0; i < input.length(); i++) {
+ out_data[i] = static_cast<out_type>(in_data[i] / factor);
+ if (input.IsValid(i) && (out_data[i] * factor != in_data[i])) {
+ std::stringstream ss;
+ ss << "Casting from " << input.type()->ToString() << " to "
+ << output->type->ToString() << " would lose data: " << in_data[i];
+ ctx->SetStatus(Status::Invalid(ss.str()));
+ break;
+ }
+ }
+ }
+ }
+}
+
+namespace {
+
+// {is_multiply, factor}
+const std::pair<bool, int64_t> kTimeConversionTable[4][4] = {
+ {{true, 1}, {true, 1000}, {true, 1000000}, {true, 1000000000L}}, //
SECOND
+ {{false, 1000}, {true, 1}, {true, 1000}, {true, 1000000}}, //
MILLI
+ {{false, 1000000}, {false, 1000}, {true, 1}, {true, 1000}}, //
MICRO
+ {{false, 1000000000L}, {false, 1000000}, {false, 1000}, {true, 1}}, //
NANO
+};
+
+} // namespace
+
+template <>
+struct CastFunctor<TimestampType, TimestampType> {
+ void operator()(FunctionContext* ctx, const CastOptions& options, const
Array& input,
+ ArrayData* output) {
+ // If units are the same, zero copy, otherwise convert
+ const auto& in_type = static_cast<const TimestampType&>(*input.type());
+ const auto& out_type = static_cast<const TimestampType&>(*output->type);
+
+ if (in_type.unit() == out_type.unit()) {
+ CopyData(input, output);
+ return;
+ }
+
+ std::pair<bool, int64_t> conversion =
+ kTimeConversionTable[static_cast<int>(in_type.unit())]
+ [static_cast<int>(out_type.unit())];
+
+ ShiftTime<int64_t, int64_t>(ctx, options, conversion.first,
conversion.second, input,
+ output);
+ }
+};
+
+// ----------------------------------------------------------------------
+// From one time32 or time64 to another
+
+template <typename O, typename I>
+struct CastFunctor<O, I,
+ typename std::enable_if<std::is_base_of<TimeType, I>::value
&&
+ std::is_base_of<TimeType,
O>::value>::type> {
+ void operator()(FunctionContext* ctx, const CastOptions& options, const
Array& input,
+ ArrayData* output) {
+ using in_t = typename I::c_type;
+ using out_t = typename O::c_type;
+
+ // If units are the same, zero copy, otherwise convert
+ const auto& in_type = static_cast<const I&>(*input.type());
+ const auto& out_type = static_cast<const O&>(*output->type);
+
+ if (in_type.unit() == out_type.unit()) {
+ CopyData(input, output);
+ return;
+ }
+
+ std::pair<bool, int64_t> conversion =
+ kTimeConversionTable[static_cast<int>(in_type.unit())]
+ [static_cast<int>(out_type.unit())];
+
+ ShiftTime<in_t, out_t>(ctx, options, conversion.first, conversion.second,
input,
+ output);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Between date32 and date64
+
+constexpr int64_t kMillisecondsInDay = 86400000;
+
+template <>
+struct CastFunctor<Date64Type, Date32Type> {
+ void operator()(FunctionContext* ctx, const CastOptions& options, const
Array& input,
+ ArrayData* output) {
+ ShiftTime<int32_t, int64_t>(ctx, options, true, kMillisecondsInDay, input,
output);
+ }
+};
+
+template <>
+struct CastFunctor<Date32Type, Date64Type> {
+ void operator()(FunctionContext* ctx, const CastOptions& options, const
Array& input,
+ ArrayData* output) {
+ ShiftTime<int64_t, int32_t>(ctx, options, false, kMillisecondsInDay,
input, output);
+ }
+};
+
+// ----------------------------------------------------------------------
// Dictionary to other things
template <typename IndexType>
@@ -271,9 +415,8 @@ void UnpackFixedSizeBinaryDictionary(FunctionContext* ctx,
const Array& indices,
internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(),
indices.offset(),
indices.length());
- const index_c_type* in =
- reinterpret_cast<const
index_c_type*>(indices.data()->buffers[1]->data()) +
- indices.offset();
+ const index_c_type* in = GetValuesAs<index_c_type>(*indices.data(), 1);
+
uint8_t* out = output->buffers[1]->mutable_data();
int32_t byte_width =
static_cast<const FixedSizeBinaryType&>(*output->type).byte_width();
@@ -336,9 +479,7 @@ Status UnpackBinaryDictionary(FunctionContext* ctx, const
Array& indices,
internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(),
indices.offset(),
indices.length());
- const index_c_type* in =
- reinterpret_cast<const
index_c_type*>(indices.data()->buffers[1]->data()) +
- indices.offset();
+ const index_c_type* in = GetValuesAs<index_c_type>(*indices.data(), 1);
for (int64_t i = 0; i < indices.length(); ++i) {
if (valid_bits_reader.IsSet()) {
int32_t length;
@@ -409,9 +550,7 @@ void UnpackPrimitiveDictionary(const Array& indices, const
c_type* dictionary,
internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(),
indices.offset(),
indices.length());
- const index_c_type* in =
- reinterpret_cast<const
index_c_type*>(indices.data()->buffers[1]->data()) +
- indices.offset();
+ const index_c_type* in = GetValuesAs<index_c_type>(*indices.data(), 1);
for (int64_t i = 0; i < indices.length(); ++i) {
if (valid_bits_reader.IsSet()) {
out[i] = dictionary[in[i]];
@@ -436,9 +575,8 @@ struct CastFunctor<T, DictionaryType,
DCHECK(values_type.Equals(*output->type))
<< "Dictionary type: " << values_type << " target type: " <<
(*output->type);
- auto dictionary =
- reinterpret_cast<const
c_type*>(type.dictionary()->data()->buffers[1]->data()) +
- type.dictionary()->offset();
+ const c_type* dictionary = GetValuesAs<c_type>(*type.dictionary()->data(),
1);
+
auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
const Array& indices = *dict_array.indices();
switch (indices.type()->id()) {
@@ -481,6 +619,9 @@ static Status AllocateIfNotPreallocated(FunctionContext*
ctx, const Array& input
int64_t bitmap_size = BitUtil::BytesForBits(length);
RETURN_NOT_OK(ctx->Allocate(bitmap_size, &validity_bitmap));
memset(validity_bitmap->mutable_data(), 0, bitmap_size);
+ } else if (input.offset() != 0) {
+ RETURN_NOT_OK(CopyBitmap(ctx->memory_pool(), validity_bitmap->data(),
input.offset(),
+ length, &validity_bitmap));
}
if (out->buffers.size() == 2) {
@@ -598,13 +739,21 @@ class CastKernel : public UnaryKernel {
FN(Int64Type, Time64Type); \
FN(Int64Type, Date64Type);
-#define DATE32_CASES(FN, IN_TYPE) FN(Date32Type, Date32Type);
+#define DATE32_CASES(FN, IN_TYPE) \
+ FN(Date32Type, Date32Type); \
+ FN(Date32Type, Date64Type);
-#define DATE64_CASES(FN, IN_TYPE) FN(Date64Type, Date64Type);
+#define DATE64_CASES(FN, IN_TYPE) \
+ FN(Date64Type, Date64Type); \
+ FN(Date64Type, Date32Type);
-#define TIME32_CASES(FN, IN_TYPE) FN(Time32Type, Time32Type);
+#define TIME32_CASES(FN, IN_TYPE) \
+ FN(Time32Type, Time32Type); \
+ FN(Time32Type, Time64Type);
-#define TIME64_CASES(FN, IN_TYPE) FN(Time64Type, Time64Type);
+#define TIME64_CASES(FN, IN_TYPE) \
+ FN(Time64Type, Time32Type); \
+ FN(Time64Type, Time64Type);
#define TIMESTAMP_CASES(FN, IN_TYPE) FN(TimestampType, TimestampType);
diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h
index 7a07512b2..d7bde20d6 100644
--- a/cpp/src/arrow/compute/cast.h
+++ b/cpp/src/arrow/compute/cast.h
@@ -34,9 +34,10 @@ class FunctionContext;
class UnaryKernel;
struct CastOptions {
- CastOptions() : allow_int_overflow(false) {}
+ CastOptions() : allow_int_overflow(false), allow_time_truncate(false) {}
bool allow_int_overflow;
+ bool allow_time_truncate;
};
/// \since 0.7.0
diff --git a/cpp/src/arrow/compute/compute-test.cc
b/cpp/src/arrow/compute/compute-test.cc
index 8a595178d..8a7ef923b 100644
--- a/cpp/src/arrow/compute/compute-test.cc
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -68,7 +68,7 @@ class TestCast : public ComputeFixture, public TestBase {
const std::shared_ptr<DataType>& out_type, const CastOptions&
options) {
std::shared_ptr<Array> result;
ASSERT_OK(Cast(&ctx_, input, out_type, options, &result));
- AssertArraysEqual(expected, *result);
+ ASSERT_ARRAYS_EQUAL(expected, *result);
}
template <typename InType, typename I_TYPE>
@@ -105,6 +105,11 @@ class TestCast : public ComputeFixture, public TestBase {
ArrayFromVector<OutType, O_TYPE>(out_type, out_values, &expected);
}
CheckPass(*input, *expected, out_type, options);
+
+ // Check a sliced variant
+ if (input->length() > 1) {
+ CheckPass(*input->Slice(1), *expected->Slice(1), out_type, options);
+ }
}
};
@@ -270,6 +275,205 @@ TEST_F(TestCast, ToIntDowncastUnsafe) {
options);
}
+TEST_F(TestCast, TimestampToTimestamp) {
+ CastOptions options;
+
+ auto CheckTimestampCast = [this](
+ const CastOptions& options, TimeUnit::type from_unit, TimeUnit::type
to_unit,
+ const std::vector<int64_t>& from_values, const std::vector<int64_t>&
to_values,
+ const std::vector<bool>& is_valid) {
+ CheckCase<TimestampType, int64_t, TimestampType, int64_t>(
+ timestamp(from_unit), from_values, is_valid, timestamp(to_unit),
to_values,
+ options);
+ };
+
+ vector<bool> is_valid = {true, false, true, true, true};
+
+ // Multiply promotions
+ vector<int64_t> v1 = {0, 100, 200, 1, 2};
+ vector<int64_t> e1 = {0, 100000, 200000, 1000, 2000};
+ CheckTimestampCast(options, TimeUnit::SECOND, TimeUnit::MILLI, v1, e1,
is_valid);
+
+ vector<int64_t> v2 = {0, 100, 200, 1, 2};
+ vector<int64_t> e2 = {0, 100000000L, 200000000L, 1000000, 2000000};
+ CheckTimestampCast(options, TimeUnit::SECOND, TimeUnit::MICRO, v2, e2,
is_valid);
+
+ vector<int64_t> v3 = {0, 100, 200, 1, 2};
+ vector<int64_t> e3 = {0, 100000000000L, 200000000000L, 1000000000L,
2000000000L};
+ CheckTimestampCast(options, TimeUnit::SECOND, TimeUnit::NANO, v3, e3,
is_valid);
+
+ vector<int64_t> v4 = {0, 100, 200, 1, 2};
+ vector<int64_t> e4 = {0, 100000, 200000, 1000, 2000};
+ CheckTimestampCast(options, TimeUnit::MILLI, TimeUnit::MICRO, v4, e4,
is_valid);
+
+ vector<int64_t> v5 = {0, 100, 200, 1, 2};
+ vector<int64_t> e5 = {0, 100000000L, 200000000L, 1000000, 2000000};
+ CheckTimestampCast(options, TimeUnit::MILLI, TimeUnit::NANO, v5, e5,
is_valid);
+
+ vector<int64_t> v6 = {0, 100, 200, 1, 2};
+ vector<int64_t> e6 = {0, 100000, 200000, 1000, 2000};
+ CheckTimestampCast(options, TimeUnit::MICRO, TimeUnit::NANO, v6, e6,
is_valid);
+
+ // Zero copy
+ std::shared_ptr<Array> arr;
+ vector<int64_t> v7 = {0, 70000, 2000, 1000, 0};
+ ArrayFromVector<TimestampType, int64_t>(timestamp(TimeUnit::SECOND),
is_valid, v7,
+ &arr);
+ CheckZeroCopy(*arr, timestamp(TimeUnit::SECOND));
+
+ // Divide, truncate
+ vector<int64_t> v8 = {0, 100123, 200456, 1123, 2456};
+ vector<int64_t> e8 = {0, 100, 200, 1, 2};
+
+ options.allow_time_truncate = true;
+ CheckTimestampCast(options, TimeUnit::MILLI, TimeUnit::SECOND, v8, e8,
is_valid);
+ CheckTimestampCast(options, TimeUnit::MICRO, TimeUnit::MILLI, v8, e8,
is_valid);
+ CheckTimestampCast(options, TimeUnit::NANO, TimeUnit::MICRO, v8, e8,
is_valid);
+
+ vector<int64_t> v9 = {0, 100123000, 200456000, 1123000, 2456000};
+ vector<int64_t> e9 = {0, 100, 200, 1, 2};
+ CheckTimestampCast(options, TimeUnit::MICRO, TimeUnit::SECOND, v9, e9,
is_valid);
+ CheckTimestampCast(options, TimeUnit::NANO, TimeUnit::MILLI, v9, e9,
is_valid);
+
+ vector<int64_t> v10 = {0, 100123000000L, 200456000000L, 1123000000L,
2456000000};
+ vector<int64_t> e10 = {0, 100, 200, 1, 2};
+ CheckTimestampCast(options, TimeUnit::NANO, TimeUnit::SECOND, v10, e10,
is_valid);
+
+ // Disallow truncate, failures
+ options.allow_time_truncate = false;
+ CheckFails<TimestampType>(timestamp(TimeUnit::MILLI), v8, is_valid,
+ timestamp(TimeUnit::SECOND), options);
+ CheckFails<TimestampType>(timestamp(TimeUnit::MICRO), v8, is_valid,
+ timestamp(TimeUnit::MILLI), options);
+ CheckFails<TimestampType>(timestamp(TimeUnit::NANO), v8, is_valid,
+ timestamp(TimeUnit::MICRO), options);
+ CheckFails<TimestampType>(timestamp(TimeUnit::MICRO), v9, is_valid,
+ timestamp(TimeUnit::SECOND), options);
+ CheckFails<TimestampType>(timestamp(TimeUnit::NANO), v9, is_valid,
+ timestamp(TimeUnit::MILLI), options);
+ CheckFails<TimestampType>(timestamp(TimeUnit::NANO), v10, is_valid,
+ timestamp(TimeUnit::SECOND), options);
+}
+
+TEST_F(TestCast, TimeToTime) {
+ CastOptions options;
+
+ vector<bool> is_valid = {true, false, true, true, true};
+
+ // Multiply promotions
+ vector<int32_t> v1 = {0, 100, 200, 1, 2};
+ vector<int32_t> e1 = {0, 100000, 200000, 1000, 2000};
+ CheckCase<Time32Type, int32_t, Time32Type, int32_t>(
+ time32(TimeUnit::SECOND), v1, is_valid, time32(TimeUnit::MILLI), e1,
options);
+
+ vector<int32_t> v2 = {0, 100, 200, 1, 2};
+ vector<int64_t> e2 = {0, 100000000L, 200000000L, 1000000, 2000000};
+ CheckCase<Time32Type, int32_t, Time64Type, int64_t>(
+ time32(TimeUnit::SECOND), v2, is_valid, time64(TimeUnit::MICRO), e2,
options);
+
+ vector<int32_t> v3 = {0, 100, 200, 1, 2};
+ vector<int64_t> e3 = {0, 100000000000L, 200000000000L, 1000000000L,
2000000000L};
+ CheckCase<Time32Type, int32_t, Time64Type, int64_t>(
+ time32(TimeUnit::SECOND), v3, is_valid, time64(TimeUnit::NANO), e3,
options);
+
+ vector<int32_t> v4 = {0, 100, 200, 1, 2};
+ vector<int64_t> e4 = {0, 100000, 200000, 1000, 2000};
+ CheckCase<Time32Type, int32_t, Time64Type, int64_t>(
+ time32(TimeUnit::MILLI), v4, is_valid, time64(TimeUnit::MICRO), e4,
options);
+
+ vector<int32_t> v5 = {0, 100, 200, 1, 2};
+ vector<int64_t> e5 = {0, 100000000L, 200000000L, 1000000, 2000000};
+ CheckCase<Time32Type, int32_t, Time64Type, int64_t>(
+ time32(TimeUnit::MILLI), v5, is_valid, time64(TimeUnit::NANO), e5,
options);
+
+ vector<int64_t> v6 = {0, 100, 200, 1, 2};
+ vector<int64_t> e6 = {0, 100000, 200000, 1000, 2000};
+ CheckCase<Time64Type, int64_t, Time64Type, int64_t>(
+ time64(TimeUnit::MICRO), v6, is_valid, time64(TimeUnit::NANO), e6,
options);
+
+ // Zero copy
+ std::shared_ptr<Array> arr;
+ vector<int64_t> v7 = {0, 70000, 2000, 1000, 0};
+ ArrayFromVector<Time64Type, int64_t>(time64(TimeUnit::MICRO), is_valid, v7,
&arr);
+ CheckZeroCopy(*arr, time64(TimeUnit::MICRO));
+
+ // Divide, truncate
+ vector<int32_t> v8 = {0, 100123, 200456, 1123, 2456};
+ vector<int32_t> e8 = {0, 100, 200, 1, 2};
+
+ options.allow_time_truncate = true;
+ CheckCase<Time32Type, int32_t, Time32Type, int32_t>(
+ time32(TimeUnit::MILLI), v8, is_valid, time32(TimeUnit::SECOND), e8,
options);
+ CheckCase<Time64Type, int32_t, Time32Type, int32_t>(
+ time64(TimeUnit::MICRO), v8, is_valid, time32(TimeUnit::MILLI), e8,
options);
+ CheckCase<Time64Type, int32_t, Time64Type, int32_t>(
+ time64(TimeUnit::NANO), v8, is_valid, time64(TimeUnit::MICRO), e8,
options);
+
+ vector<int64_t> v9 = {0, 100123000, 200456000, 1123000, 2456000};
+ vector<int32_t> e9 = {0, 100, 200, 1, 2};
+ CheckCase<Time64Type, int64_t, Time32Type, int32_t>(
+ time64(TimeUnit::MICRO), v9, is_valid, time32(TimeUnit::SECOND), e9,
options);
+ CheckCase<Time64Type, int64_t, Time32Type, int32_t>(
+ time64(TimeUnit::NANO), v9, is_valid, time32(TimeUnit::MILLI), e9,
options);
+
+ vector<int64_t> v10 = {0, 100123000000L, 200456000000L, 1123000000L,
2456000000};
+ vector<int32_t> e10 = {0, 100, 200, 1, 2};
+ CheckCase<Time64Type, int64_t, Time32Type, int32_t>(
+ time64(TimeUnit::NANO), v10, is_valid, time32(TimeUnit::SECOND), e10,
options);
+
+ // Disallow truncate, failures
+
+ options.allow_time_truncate = false;
+ CheckFails<Time32Type>(time32(TimeUnit::MILLI), v8, is_valid,
time32(TimeUnit::SECOND),
+ options);
+ CheckFails<Time64Type>(time64(TimeUnit::MICRO), v8, is_valid,
time32(TimeUnit::MILLI),
+ options);
+ CheckFails<Time64Type>(time64(TimeUnit::NANO), v8, is_valid,
time64(TimeUnit::MICRO),
+ options);
+ CheckFails<Time64Type>(time64(TimeUnit::MICRO), v9, is_valid,
time32(TimeUnit::SECOND),
+ options);
+ CheckFails<Time64Type>(time64(TimeUnit::NANO), v9, is_valid,
time32(TimeUnit::MILLI),
+ options);
+ CheckFails<Time64Type>(time64(TimeUnit::NANO), v10, is_valid,
time32(TimeUnit::SECOND),
+ options);
+}
+
+TEST_F(TestCast, DateToDate) {
+ CastOptions options;
+
+ vector<bool> is_valid = {true, false, true, true, true};
+
+ constexpr int64_t F = 86400000;
+
+ // Multiply promotion
+ vector<int32_t> v1 = {0, 100, 200, 1, 2};
+ vector<int64_t> e1 = {0, 100 * F, 200 * F, F, 2 * F};
+ CheckCase<Date32Type, int32_t, Date64Type, int64_t>(date32(), v1, is_valid,
date64(),
+ e1, options);
+
+ // Zero copy
+ std::shared_ptr<Array> arr;
+ vector<int32_t> v2 = {0, 70000, 2000, 1000, 0};
+ vector<int64_t> v3 = {0, 70000, 2000, 1000, 0};
+ ArrayFromVector<Date32Type, int32_t>(date32(), is_valid, v2, &arr);
+ CheckZeroCopy(*arr, date32());
+
+ ArrayFromVector<Date64Type, int64_t>(date64(), is_valid, v3, &arr);
+ CheckZeroCopy(*arr, date64());
+
+ // Divide, truncate
+ vector<int64_t> v8 = {0, 100 * F + 123, 200 * F + 456, F + 123, 2 * F + 456};
+ vector<int32_t> e8 = {0, 100, 200, 1, 2};
+
+ options.allow_time_truncate = true;
+ CheckCase<Date64Type, int64_t, Date32Type, int32_t>(date64(), v8, is_valid,
date32(),
+ e8, options);
+
+ // Disallow truncate, failures
+ options.allow_time_truncate = false;
+ CheckFails<Date64Type>(date64(), v8, is_valid, date32(), options);
+}
+
TEST_F(TestCast, ToDouble) {
CastOptions options;
vector<bool> is_valid = {true, false, true, true, true};
@@ -335,7 +539,7 @@ TEST_F(TestCast, FromNull) {
ASSERT_EQ(length, result->null_count());
// OK to look at bitmaps
- AssertArraysEqual(*result, *result);
+ ASSERT_ARRAYS_EQUAL(*result, *result);
}
TEST_F(TestCast, PreallocatedMemory) {
@@ -373,7 +577,7 @@ TEST_F(TestCast, PreallocatedMemory) {
std::shared_ptr<Array> expected;
ArrayFromVector<Int64Type, int64_t>(int64(), is_valid, e1, &expected);
- AssertArraysEqual(*expected, *result);
+ ASSERT_ARRAYS_EQUAL(*expected, *result);
}
template <typename TestType>
diff --git a/cpp/src/arrow/test-util.h b/cpp/src/arrow/test-util.h
index 83ebdea4a..044fb9476 100644
--- a/cpp/src/arrow/test-util.h
+++ b/cpp/src/arrow/test-util.h
@@ -281,15 +281,20 @@ Status MakeArray(const std::vector<uint8_t>& valid_bytes,
const std::vector<T>&
return builder->Finish(out);
}
-void AssertArraysEqual(const Array& expected, const Array& actual) {
- if (!actual.Equals(expected)) {
- std::stringstream pp_result;
- std::stringstream pp_expected;
+#define ASSERT_ARRAYS_EQUAL(LEFT, RIGHT)
\
+ do {
\
+ if (!(LEFT).Equals((RIGHT))) {
\
+ std::stringstream pp_result;
\
+ std::stringstream pp_expected;
\
+
\
+ EXPECT_OK(PrettyPrint(RIGHT, 0, &pp_result));
\
+ EXPECT_OK(PrettyPrint(LEFT, 0, &pp_expected));
\
+ FAIL() << "Got: \n" << pp_result.str() << "\nExpected: \n" <<
pp_expected.str(); \
+ }
\
+ } while (false)
- EXPECT_OK(PrettyPrint(actual, 0, &pp_result));
- EXPECT_OK(PrettyPrint(expected, 0, &pp_expected));
- FAIL() << "Got: \n" << pp_result.str() << "\nExpected: \n" <<
pp_expected.str();
- }
+void AssertArraysEqual(const Array& expected, const Array& actual) {
+ ASSERT_ARRAYS_EQUAL(expected, actual);
}
#define ASSERT_BATCHES_EQUAL(LEFT, RIGHT) \
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 443828423..878fdf29e 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -228,7 +228,11 @@ class ARROW_EXPORT FloatingPoint : public Number {
virtual Precision precision() const = 0;
};
-class ARROW_EXPORT NestedType : public DataType {
+/// \class ParametricType
+/// \brief A superclass for types having additional metadata
+class ParametricType {};
+
+class ARROW_EXPORT NestedType : public DataType, public ParametricType {
public:
using DataType::DataType;
};
@@ -444,7 +448,7 @@ class ARROW_EXPORT BinaryType : public DataType, public
NoExtraMeta {
};
// BinaryType type is represents lists of 1-byte values.
-class ARROW_EXPORT FixedSizeBinaryType : public FixedWidthType {
+class ARROW_EXPORT FixedSizeBinaryType : public FixedWidthType, public
ParametricType {
public:
static constexpr Type::type type_id = Type::FIXED_SIZE_BINARY;
@@ -611,7 +615,7 @@ static inline std::ostream& operator<<(std::ostream& os,
TimeUnit::type unit) {
return os;
}
-class ARROW_EXPORT TimeType : public FixedWidthType {
+class ARROW_EXPORT TimeType : public FixedWidthType, public ParametricType {
public:
TimeUnit::type unit() const { return unit_; }
@@ -650,7 +654,7 @@ class ARROW_EXPORT Time64Type : public TimeType {
std::string name() const override { return "time64"; }
};
-class ARROW_EXPORT TimestampType : public FixedWidthType {
+class ARROW_EXPORT TimestampType : public FixedWidthType, public
ParametricType {
public:
using Unit = TimeUnit;
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index c596d2ad8..72262f0c9 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -260,8 +260,8 @@ cdef class Array:
type = _ensure_type(target_type)
- if not safe:
- options.allow_int_overflow = 1
+ options.allow_int_overflow = not safe
+ options.allow_time_truncate = not safe
with nogil:
check_status(Cast(_context(), self.ap[0], type.sp_type,
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index 0e5d4a8ed..809bb96b7 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -747,6 +747,7 @@ cdef extern from "arrow/compute/api.h" namespace
"arrow::compute" nogil:
cdef cppclass CCastOptions" arrow::compute::CastOptions":
c_bool allow_int_overflow
+ c_bool allow_time_truncate
CStatus Cast(CFunctionContext* context, const CArray& array,
const shared_ptr[CDataType]& to_type,
diff --git a/python/pyarrow/tests/test_array.py
b/python/pyarrow/tests/test_array.py
index 418076f81..e3a4c9756 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -290,6 +290,32 @@ def test_cast_integers_unsafe():
_check_cast_case(case, safe=False)
+def test_cast_timestamp_unit():
+ # ARROW-1680
+ val = datetime.datetime.now()
+ s = pd.Series([val])
+ s_nyc = s.dt.tz_localize('tzlocal()').dt.tz_convert('America/New_York')
+
+ us_with_tz = pa.timestamp('us', tz='America/New_York')
+ arr = pa.Array.from_pandas(s_nyc, type=us_with_tz)
+
+ arr2 = pa.Array.from_pandas(s, type=pa.timestamp('us'))
+
+ assert arr[0].as_py() == s_nyc[0]
+ assert arr2[0].as_py() == s[0]
+
+ # Disallow truncation
+ arr = pa.array([123123], type='int64').cast(pa.timestamp('ms'))
+ expected = pa.array([123], type='int64').cast(pa.timestamp('s'))
+
+ target = pa.timestamp('s')
+ with pytest.raises(ValueError):
+ arr.cast(target)
+
+ result = arr.cast(target, safe=False)
+ assert result.equals(expected)
+
+
def test_cast_signed_to_unsigned():
safe_cases = [
(np.array([0, 1, 2, 3], dtype='i1'), pa.uint8(),
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
> [C++] Implement (safe and unsafe) casts between timestamps and times of
> different units
> ---------------------------------------------------------------------------------------
>
> Key: ARROW-1484
> URL: https://issues.apache.org/jira/browse/ARROW-1484
> Project: Apache Arrow
> Issue Type: New Feature
> Components: C++
> Reporter: Wes McKinney
> Assignee: Wes McKinney
> Labels: pull-request-available
> Fix For: 0.8.0
>
>
--
This message was sent by Atlassian JIRA
(v6.4.14#64029)