This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-nanoarrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d917f29c test: test with the `HalfFloatType` from arrow (#503)
d917f29c is described below

commit d917f29cca7c52033177d7b456d44c26b0b298c9
Author: Cocoa <[email protected]>
AuthorDate: Tue Jun 4 13:54:55 2024 +0100

    test: test with the `HalfFloatType` from arrow (#503)
    
    Hi, this PR adds an extra test case that exercises `ArrowFloatToHalfFloat`
    and `ArrowArrayViewGet{Double,Int,UInt}Unsafe` with the `HalfFloatType`
    from Arrow C++.
    
    
    
https://github.com/apache/arrow/blob/255dbf990c3d3e5fb1270a2a11efe0af2be195ab/cpp/src/arrow/type.h#L704-L713
    
    And sorry: I didn't know arrow-cpp already had a `HalfFloatType` when
    I was working on #501 🥲.
    
    ~Sadly, we cannot simply append
    `TestGetFromNumericArrayView<HalfFloatType>();` at the end of the
    `ArrayViewTestGetNumeric` test suite, because we have to convert floats
    to half-floats using `ArrowFloatToHalfFloat` before calling
    `builder.Append` (otherwise we get incorrect values back from the
    subsequent `ArrowArrayViewGet{Double,Int,UInt}Unsafe` calls).~
    
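    With a bit of C++ template magic (a `transform_value` helper that applies
    `ArrowFloatToHalfFloat` only when the type is `HalfFloatType`), we can now
    simply append `TestGetFromNumericArrayView<HalfFloatType>();` at the end
    of the `ArrayViewTestGetNumeric` test suite.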
---
 src/nanoarrow/array_test.cc         | 24 ++++++++++++++++++++----
 src/nanoarrow/nanoarrow_testing.hpp |  3 +++
 src/nanoarrow/utils_test.cc         |  4 ++--
 3 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/nanoarrow/array_test.cc b/src/nanoarrow/array_test.cc
index 35531495..ff9ffcd2 100644
--- a/src/nanoarrow/array_test.cc
+++ b/src/nanoarrow/array_test.cc
@@ -19,6 +19,7 @@
 #include <gtest/gtest.h>
 #include <cmath>
 #include <cstdint>
+#include <type_traits>
 
 #include <arrow/array.h>
 #include <arrow/array/builder_binary.h>
@@ -2423,6 +2424,20 @@ TEST(ArrayTest, ArrayViewTestSparseUnionGet) {
   ArrowArrayRelease(&array);
 }
 
+template <
+    typename TypeClass, typename ValueType,
+    typename std::enable_if<std::is_same_v<TypeClass, HalfFloatType>, bool>::type = true>
+auto transform_value(ValueType t) -> uint16_t {
+  return ArrowFloatToHalfFloat(t);
+}
+
+template <
+    typename TypeClass, typename ValueType,
+    typename std::enable_if<!std::is_same_v<TypeClass, HalfFloatType>, bool>::type = true>
+auto transform_value(ValueType t) -> ValueType {
+  return t;
+}
+
 template <typename TypeClass>
 void TestGetFromNumericArrayView() {
   struct ArrowArray array;
@@ -2434,9 +2449,9 @@ void TestGetFromNumericArrayView() {
 
   // Array with nulls
   auto builder = NumericBuilder<TypeClass>();
-  ARROW_EXPECT_OK(builder.Append(1));
+  ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(1)));
   ARROW_EXPECT_OK(builder.AppendNulls(2));
-  ARROW_EXPECT_OK(builder.Append(4));
+  ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(4)));
   auto maybe_arrow_array = builder.Finish();
   ARROW_EXPECT_OK(maybe_arrow_array);
   auto arrow_array = maybe_arrow_array.ValueUnsafe();
@@ -2467,8 +2482,8 @@ void TestGetFromNumericArrayView() {
 
   // Array without nulls (Arrow does not allocate the validity buffer)
   builder = NumericBuilder<TypeClass>();
-  ARROW_EXPECT_OK(builder.Append(1));
-  ARROW_EXPECT_OK(builder.Append(2));
+  ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(1)));
+  ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(2)));
   maybe_arrow_array = builder.Finish();
   ARROW_EXPECT_OK(maybe_arrow_array);
   arrow_array = maybe_arrow_array.ValueUnsafe();
@@ -2504,6 +2519,7 @@ TEST(ArrayViewTest, ArrayViewTestGetNumeric) {
   TestGetFromNumericArrayView<UInt32Type>();
   TestGetFromNumericArrayView<DoubleType>();
   TestGetFromNumericArrayView<FloatType>();
+  TestGetFromNumericArrayView<HalfFloatType>();
 }
 
 TEST(ArrayViewTest, ArrayViewTestGetFloat16) {
diff --git a/src/nanoarrow/nanoarrow_testing.hpp b/src/nanoarrow/nanoarrow_testing.hpp
index f7d2da4f..93137268 100644
--- a/src/nanoarrow/nanoarrow_testing.hpp
+++ b/src/nanoarrow/nanoarrow_testing.hpp
@@ -804,6 +804,7 @@ class TestingJSONWriter {
         }
         break;
 
+      case NANOARROW_TYPE_HALF_FLOAT:
       case NANOARROW_TYPE_FLOAT:
       case NANOARROW_TYPE_DOUBLE: {
         // JSON number to float_precision_ decimal places
@@ -2146,6 +2147,8 @@ class TestingJSONReader {
           case NANOARROW_TYPE_UINT64:
             return SetBufferInt<uint64_t, uint64_t>(data, buffer, error);
 
+          case NANOARROW_TYPE_HALF_FLOAT:
+            return SetBufferFloatingPoint<float>(data, buffer, error);
           case NANOARROW_TYPE_FLOAT:
             return SetBufferFloatingPoint<float>(data, buffer, error);
           case NANOARROW_TYPE_DOUBLE:
diff --git a/src/nanoarrow/utils_test.cc b/src/nanoarrow/utils_test.cc
index 7ff55941..24892a6e 100644
--- a/src/nanoarrow/utils_test.cc
+++ b/src/nanoarrow/utils_test.cc
@@ -547,8 +547,8 @@ TEST(DecimalTest, DecimalRoundtripBitshiftTest) {
 // https://github.com/apache/arrow/blob/main/go/arrow/float16/float16_test.go
 TEST(HalfFloatTest, FloatAndHalfFloatRoundTrip) {
   uint16_t cases_bits[] = {
-      0x8000,  0x7c00, 0xfc00, 0x3c00, 0x4000, 0xc000,
-      +0x0000, 0x5b8f, 0xdb8f, 0x48c8, 0xc8c8,
+      0x8000, 0x7c00, 0xfc00, 0x3c00, 0x4000, 0xc000,
+      0x0000, 0x5b8f, 0xdb8f, 0x48c8, 0xc8c8,
   };
   float cases_float[] = {
       -0.0, INFINITY, -INFINITY, 1, 2, -2, 0, 241.875, -241.875, 9.5625, 
-9.5625,

Reply via email to