This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-nanoarrow.git
The following commit(s) were added to refs/heads/main by this push:
new d917f29c test: test with the `HalfFloatType` from arrow (#503)
d917f29c is described below
commit d917f29cca7c52033177d7b456d44c26b0b298c9
Author: Cocoa <[email protected]>
AuthorDate: Tue Jun 4 13:54:55 2024 +0100
test: test with the `HalfFloatType` from arrow (#503)
Hi, this PR adds an extra test case that exercises `ArrowFloatToHalfFloat`
and `ArrowArrayViewGet{Double,Int,UInt}Unsafe` with the `HalfFloatType`
from Arrow C++.
https://github.com/apache/arrow/blob/255dbf990c3d3e5fb1270a2a11efe0af2be195ab/cpp/src/arrow/type.h#L704-L713
Sorry, I didn't know there was a `HalfFloatType` in arrow-cpp when
I was working on #501 🥲.
~Sadly, we cannot simply append
`TestGetFromNumericArrayView<HalfFloatType>();` at the end of the
`ArrayViewTestGetNumeric` test suite, because we have to convert floats
to half-floats using `ArrowFloatToHalfFloat` before calling
`builder.Append` (otherwise we get incorrect values back from the
subsequent `ArrowArrayViewGet{Double,Int,UInt}Unsafe` calls).~
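With a little C++ template magic (a `transform_value` helper, selected via
`std::enable_if`, that converts values with `ArrowFloatToHalfFloat` only for
`HalfFloatType`), we can now simply append
`TestGetFromNumericArrayView<HalfFloatType>();` at the end of the
`ArrayViewTestGetNumeric` test suite.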
---
src/nanoarrow/array_test.cc | 24 ++++++++++++++++++++----
src/nanoarrow/nanoarrow_testing.hpp | 3 +++
src/nanoarrow/utils_test.cc | 4 ++--
3 files changed, 25 insertions(+), 6 deletions(-)
diff --git a/src/nanoarrow/array_test.cc b/src/nanoarrow/array_test.cc
index 35531495..ff9ffcd2 100644
--- a/src/nanoarrow/array_test.cc
+++ b/src/nanoarrow/array_test.cc
@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <cmath>
#include <cstdint>
+#include <type_traits>
#include <arrow/array.h>
#include <arrow/array/builder_binary.h>
@@ -2423,6 +2424,20 @@ TEST(ArrayTest, ArrayViewTestSparseUnionGet) {
ArrowArrayRelease(&array);
}
+template <
+ typename TypeClass, typename ValueType,
+ typename std::enable_if<std::is_same_v<TypeClass, HalfFloatType>,
bool>::type = true>
+auto transform_value(ValueType t) -> uint16_t {
+ return ArrowFloatToHalfFloat(t);
+}
+
+template <
+ typename TypeClass, typename ValueType,
+ typename std::enable_if<!std::is_same_v<TypeClass, HalfFloatType>,
bool>::type = true>
+auto transform_value(ValueType t) -> ValueType {
+ return t;
+}
+
template <typename TypeClass>
void TestGetFromNumericArrayView() {
struct ArrowArray array;
@@ -2434,9 +2449,9 @@ void TestGetFromNumericArrayView() {
// Array with nulls
auto builder = NumericBuilder<TypeClass>();
- ARROW_EXPECT_OK(builder.Append(1));
+ ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(1)));
ARROW_EXPECT_OK(builder.AppendNulls(2));
- ARROW_EXPECT_OK(builder.Append(4));
+ ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(4)));
auto maybe_arrow_array = builder.Finish();
ARROW_EXPECT_OK(maybe_arrow_array);
auto arrow_array = maybe_arrow_array.ValueUnsafe();
@@ -2467,8 +2482,8 @@ void TestGetFromNumericArrayView() {
// Array without nulls (Arrow does not allocate the validity buffer)
builder = NumericBuilder<TypeClass>();
- ARROW_EXPECT_OK(builder.Append(1));
- ARROW_EXPECT_OK(builder.Append(2));
+ ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(1)));
+ ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(2)));
maybe_arrow_array = builder.Finish();
ARROW_EXPECT_OK(maybe_arrow_array);
arrow_array = maybe_arrow_array.ValueUnsafe();
@@ -2504,6 +2519,7 @@ TEST(ArrayViewTest, ArrayViewTestGetNumeric) {
TestGetFromNumericArrayView<UInt32Type>();
TestGetFromNumericArrayView<DoubleType>();
TestGetFromNumericArrayView<FloatType>();
+ TestGetFromNumericArrayView<HalfFloatType>();
}
TEST(ArrayViewTest, ArrayViewTestGetFloat16) {
diff --git a/src/nanoarrow/nanoarrow_testing.hpp b/src/nanoarrow/nanoarrow_testing.hpp
index f7d2da4f..93137268 100644
--- a/src/nanoarrow/nanoarrow_testing.hpp
+++ b/src/nanoarrow/nanoarrow_testing.hpp
@@ -804,6 +804,7 @@ class TestingJSONWriter {
}
break;
+ case NANOARROW_TYPE_HALF_FLOAT:
case NANOARROW_TYPE_FLOAT:
case NANOARROW_TYPE_DOUBLE: {
// JSON number to float_precision_ decimal places
@@ -2146,6 +2147,8 @@ class TestingJSONReader {
case NANOARROW_TYPE_UINT64:
return SetBufferInt<uint64_t, uint64_t>(data, buffer, error);
+ case NANOARROW_TYPE_HALF_FLOAT:
+ return SetBufferFloatingPoint<float>(data, buffer, error);
case NANOARROW_TYPE_FLOAT:
return SetBufferFloatingPoint<float>(data, buffer, error);
case NANOARROW_TYPE_DOUBLE:
diff --git a/src/nanoarrow/utils_test.cc b/src/nanoarrow/utils_test.cc
index 7ff55941..24892a6e 100644
--- a/src/nanoarrow/utils_test.cc
+++ b/src/nanoarrow/utils_test.cc
@@ -547,8 +547,8 @@ TEST(DecimalTest, DecimalRoundtripBitshiftTest) {
// https://github.com/apache/arrow/blob/main/go/arrow/float16/float16_test.go
TEST(HalfFloatTest, FloatAndHalfFloatRoundTrip) {
uint16_t cases_bits[] = {
- 0x8000, 0x7c00, 0xfc00, 0x3c00, 0x4000, 0xc000,
- +0x0000, 0x5b8f, 0xdb8f, 0x48c8, 0xc8c8,
+ 0x8000, 0x7c00, 0xfc00, 0x3c00, 0x4000, 0xc000,
+ 0x0000, 0x5b8f, 0xdb8f, 0x48c8, 0xc8c8,
};
float cases_float[] = {
-0.0, INFINITY, -INFINITY, 1, 2, -2, 0, 241.875, -241.875, 9.5625,
-9.5625,