This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-nanoarrow.git


The following commit(s) were added to refs/heads/main by this push:
     new b617f44  feat(extensions/nanoarrow_ipc): Improve type coverage of schema field decode (#115)
b617f44 is described below

commit b617f44644ee91210402c0181560af6f9a579c50
Author: Dewey Dunnington <[email protected]>
AuthorDate: Thu Feb 23 09:30:48 2023 -0400

    feat(extensions/nanoarrow_ipc): Improve type coverage of schema field decode (#115)
    
    Closes #90.
---
 .../nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc.c    | 454 ++++++++++++++++++++-
 .../src/nanoarrow/nanoarrow_ipc_test.cc            | 131 ++++++
 2 files changed, 574 insertions(+), 11 deletions(-)

diff --git a/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc.c b/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc.c
index 1ae14b9..b2d7cbe 100644
--- a/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc.c
+++ b/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc.c
@@ -16,6 +16,7 @@
 // under the License.
 
 #include <errno.h>
+#include <stdio.h>
 #include <string.h>
 
 #include "nanoarrow.h"
@@ -67,6 +68,59 @@ static inline int32_t ArrowIpcReadInt32LE(struct ArrowBufferView* data) {
 
 #define ns(x) FLATBUFFERS_WRAP_NAMESPACE(org_apache_arrow_flatbuf, x)
 
+static int ArrowIpcReaderSetMetadata(struct ArrowSchema* schema,
+                                     ns(KeyValue_vec_t) kv_vec,
+                                     struct ArrowError* error) {
+  int64_t n_pairs = ns(KeyValue_vec_len(kv_vec));
+  if (n_pairs == 0) {
+    return NANOARROW_OK;
+  }
+
+  if (n_pairs > 2147483647) {
+    ArrowErrorSet(error,
+                  "Expected between 0 and 2147483647 key/value pairs but found %ld",
+                  (long)n_pairs);
+    return EINVAL;
+  }
+
+  struct ArrowBuffer buf;
+  struct ArrowStringView key;
+  struct ArrowStringView value;
+  ns(KeyValue_table_t) kv;
+
+  int result = ArrowMetadataBuilderInit(&buf, NULL);
+  if (result != NANOARROW_OK) {
+    ArrowBufferReset(&buf);
+    ArrowErrorSet(error, "ArrowMetadataBuilderInit() failed");
+    return result;
+  }
+
+  for (int64_t i = 0; i < n_pairs; i++) {
+    kv = ns(KeyValue_vec_at(kv_vec, i));
+
+    key.data = ns(KeyValue_key(kv));
+    key.size_bytes = strlen(key.data);
+    value.data = ns(KeyValue_value(kv));
+    value.size_bytes = strlen(value.data);
+
+    result = ArrowMetadataBuilderAppend(&buf, key, value);
+    if (result != NANOARROW_OK) {
+      ArrowBufferReset(&buf);
+      ArrowErrorSet(error, "ArrowMetadataBuilderAppend() failed");
+      return result;
+    }
+  }
+
+  result = ArrowSchemaSetMetadata(schema, (const char*)buf.data);
+  ArrowBufferReset(&buf);
+  if (result != NANOARROW_OK) {
+    ArrowErrorSet(error, "ArrowSchemaSetMetadata() failed");
+    return result;
+  }
+
+  return NANOARROW_OK;
+}
+
 static int ArrowIpcReaderSetTypeSimple(struct ArrowSchema* schema, int nanoarrow_type,
                                        struct ArrowError* error) {
   int result = ArrowSchemaSetType(schema, nanoarrow_type);
@@ -133,8 +187,371 @@ static int ArrowIpcReaderSetTypeInt(struct ArrowSchema* schema,
   return ArrowIpcReaderSetTypeSimple(schema, nanoarrow_type, error);
 }
 
+static int ArrowIpcReaderSetTypeFloatingPoint(struct ArrowSchema* schema,
+                                              flatbuffers_generic_t type_generic,
+                                              struct ArrowError* error) {
+  ns(FloatingPoint_table_t) type = (ns(FloatingPoint_table_t))type_generic;
+  int precision = ns(FloatingPoint_precision(type));
+  switch (precision) {
+    case ns(Precision_HALF):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_HALF_FLOAT, error);
+    case ns(Precision_SINGLE):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_FLOAT, error);
+    case ns(Precision_DOUBLE):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_DOUBLE, error);
+    default:
+      ArrowErrorSet(error, "Unexpected FloatingPoint Precision value: %d",
+                    (int)precision);
+      return EINVAL;
+  }
+}
+
+static int ArrowIpcReaderSetTypeDecimal(struct ArrowSchema* schema,
+                                        flatbuffers_generic_t type_generic,
+                                        struct ArrowError* error) {
+  ns(Decimal_table_t) type = (ns(Decimal_table_t))type_generic;
+  int scale = ns(Decimal_scale(type));
+  int precision = ns(Decimal_precision(type));
+  int bitwidth = ns(Decimal_bitWidth(type));
+
+  int result;
+  switch (bitwidth) {
+    case 128:
+      result =
+          ArrowSchemaSetTypeDecimal(schema, NANOARROW_TYPE_DECIMAL128, precision, scale);
+      break;
+    case 256:
+      result =
+          ArrowSchemaSetTypeDecimal(schema, NANOARROW_TYPE_DECIMAL256, precision, scale);
+      break;
+    default:
+      ArrowErrorSet(error, "Unexpected Decimal bitwidth value: %d", (int)bitwidth);
+      return EINVAL;
+  }
+
+  if (result != NANOARROW_OK) {
+    ArrowErrorSet(error, "ArrowSchemaSetTypeDecimal() failed");
+    return result;
+  }
+
+  return NANOARROW_OK;
+}
+
+static int ArrowIpcReaderSetTypeFixedSizeBinary(struct ArrowSchema* schema,
+                                                flatbuffers_generic_t type_generic,
+                                                struct ArrowError* error) {
+  ns(FixedSizeBinary_table_t) type = (ns(FixedSizeBinary_table_t))type_generic;
+  int fixed_size = ns(FixedSizeBinary_byteWidth(type));
+  return ArrowSchemaSetTypeFixedSize(schema, NANOARROW_TYPE_FIXED_SIZE_BINARY,
+                                     fixed_size);
+}
+
+static int ArrowIpcReaderSetTypeDate(struct ArrowSchema* schema,
+                                     flatbuffers_generic_t type_generic,
+                                     struct ArrowError* error) {
+  ns(Date_table_t) type = (ns(Date_table_t))type_generic;
+  int date_unit = ns(Date_unit(type));
+  switch (date_unit) {
+    case ns(DateUnit_DAY):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_DATE32, error);
+    case ns(DateUnit_MILLISECOND):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_DATE64, error);
+    default:
+      ArrowErrorSet(error, "Unexpected Date DateUnit value: %d", (int)date_unit);
+      return EINVAL;
+  }
+}
+
+static int ArrowIpcReaderSetTypeTime(struct ArrowSchema* schema,
+                                     flatbuffers_generic_t type_generic,
+                                     struct ArrowError* error) {
+  ns(Time_table_t) type = (ns(Time_table_t))type_generic;
+  int time_unit = ns(Time_unit(type));
+  int bitwidth = ns(Time_bitWidth(type));
+  int nanoarrow_type;
+
+  switch (time_unit) {
+    case ns(TimeUnit_SECOND):
+    case ns(TimeUnit_MILLISECOND):
+      if (bitwidth != 32) {
+        ArrowErrorSet(error, "Expected bitwidth of 32 for Time TimeUnit %s but found %d",
+                      ns(TimeUnit_name(time_unit)), bitwidth);
+        return EINVAL;
+      }
+
+      nanoarrow_type = NANOARROW_TYPE_TIME32;
+      break;
+
+    case ns(TimeUnit_MICROSECOND):
+    case ns(TimeUnit_NANOSECOND):
+      if (bitwidth != 64) {
+        ArrowErrorSet(error, "Expected bitwidth of 64 for Time TimeUnit %s but found %d",
+                      ns(TimeUnit_name(time_unit)), bitwidth);
+        return EINVAL;
+      }
+
+      nanoarrow_type = NANOARROW_TYPE_TIME64;
+      break;
+
+    default:
+      ArrowErrorSet(error, "Unexpected Time TimeUnit value: %d", (int)time_unit);
+      return EINVAL;
+  }
+
+  int result = ArrowSchemaSetTypeDateTime(schema, nanoarrow_type, time_unit, NULL);
+  if (result != NANOARROW_OK) {
+    ArrowErrorSet(error, "ArrowSchemaSetTypeDateTime() failed");
+    return result;
+  }
+
+  return NANOARROW_OK;
+}
+
+static int ArrowIpcReaderSetTypeTimestamp(struct ArrowSchema* schema,
+                                          flatbuffers_generic_t type_generic,
+                                          struct ArrowError* error) {
+  ns(Timestamp_table_t) type = (ns(Timestamp_table_t))type_generic;
+  int time_unit = ns(Timestamp_unit(type));
+
+  const char* timezone = "";
+  if (ns(Timestamp_timezone_is_present(type))) {
+    timezone = ns(Timestamp_timezone_get(type));
+  }
+
+  int result =
+      ArrowSchemaSetTypeDateTime(schema, NANOARROW_TYPE_TIMESTAMP, time_unit, timezone);
+  if (result != NANOARROW_OK) {
+    ArrowErrorSet(error, "ArrowSchemaSetTypeDateTime() failed");
+    return result;
+  }
+
+  return NANOARROW_OK;
+}
+
+static int ArrowIpcReaderSetTypeDuration(struct ArrowSchema* schema,
+                                         flatbuffers_generic_t type_generic,
+                                         struct ArrowError* error) {
+  ns(Duration_table_t) type = (ns(Duration_table_t))type_generic;
+  int time_unit = ns(Duration_unit(type));
+
+  int result =
+      ArrowSchemaSetTypeDateTime(schema, NANOARROW_TYPE_DURATION, time_unit, NULL);
+  if (result != NANOARROW_OK) {
+    ArrowErrorSet(error, "ArrowSchemaSetTypeDateTime() failed");
+    return result;
+  }
+
+  return NANOARROW_OK;
+}
+
+static int ArrowIpcReaderSetTypeInterval(struct ArrowSchema* schema,
+                                         flatbuffers_generic_t type_generic,
+                                         struct ArrowError* error) {
+  ns(Interval_table_t) type = (ns(Interval_table_t))type_generic;
+  int interval_unit = ns(Interval_unit(type));
+
+  switch (interval_unit) {
+    case ns(IntervalUnit_YEAR_MONTH):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_INTERVAL_MONTHS, error);
+    case ns(IntervalUnit_DAY_TIME):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_INTERVAL_DAY_TIME, error);
+    case ns(IntervalUnit_MONTH_DAY_NANO):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO,
+                                         error);
+    default:
+      ArrowErrorSet(error, "Unexpected Interval unit value: %d", (int)interval_unit);
+      return EINVAL;
+  }
+}
+
+// We can't quite use nanoarrow's built-in SchemaSet functions for nested types
+// because the IPC format allows modifying some of the defaults those functions assume.
+// In particular, the allocate + initialize children step is handled outside these
+// setters.
+static int ArrowIpcReaderSetTypeSimpleNested(struct ArrowSchema* schema,
+                                             const char* format,
+                                             struct ArrowError* error) {
+  int result = ArrowSchemaSetFormat(schema, format);
+  if (result != NANOARROW_OK) {
+    ArrowErrorSet(error, "ArrowSchemaSetFormat('%s') failed", format);
+    return result;
+  }
+
+  return NANOARROW_OK;
+}
+
+static int ArrowIpcReaderSetTypeFixedSizeList(struct ArrowSchema* schema,
+                                              flatbuffers_generic_t type_generic,
+                                              struct ArrowError* error) {
+  ns(FixedSizeList_table_t) type = (ns(FixedSizeList_table_t))type_generic;
+  int32_t fixed_size = ns(FixedSizeList_listSize(type));
+
+  char fixed_size_str[128];
+  int n_chars = snprintf(fixed_size_str, 128, "+w:%d", fixed_size);
+  fixed_size_str[n_chars] = '\0';
+  return ArrowIpcReaderSetTypeSimpleNested(schema, fixed_size_str, error);
+}
+
+static int ArrowIpcReaderSetTypeMap(struct ArrowSchema* schema,
+                                    flatbuffers_generic_t type_generic,
+                                    struct ArrowError* error) {
+  ns(Map_table_t) type = (ns(Map_table_t))type_generic;
+  NANOARROW_RETURN_NOT_OK(ArrowIpcReaderSetTypeSimpleNested(schema, "+m", error));
+
+  if (ns(Map_keysSorted(type))) {
+    schema->flags |= ARROW_FLAG_MAP_KEYS_SORTED;
+  } else {
+    schema->flags &= ~ARROW_FLAG_MAP_KEYS_SORTED;
+  }
+
+  return NANOARROW_OK;
+}
+
+static int ArrowIpcReaderSetTypeUnion(struct ArrowSchema* schema,
+                                      flatbuffers_generic_t type_generic,
+                                      int64_t n_children, struct ArrowError* error) {
+  ns(Union_table_t) type = (ns(Union_table_t))type_generic;
+  int union_mode = ns(Union_mode(type));
+
+  if (n_children < 0 || n_children > 127) {
+    ArrowErrorSet(error,
+                  "Expected between 0 and 127 children for Union type but found %ld",
+                  (long)n_children);
+    return EINVAL;
+  }
+
+  // Max valid typeIds size is 127; the longest single ID that could be present here
+  // is -INT_MIN (11 chars). With commas and the prefix the max size would be
+  // 1527 characters. (Any ids outside the range 0...127 are unlikely to be valid
+  // elsewhere but they could in theory be present here).
+  char union_types_str[2048];
+  memset(union_types_str, 0, sizeof(union_types_str));
+  char* format_cursor = union_types_str;
+  int format_out_size = sizeof(union_types_str);
+  int n_chars = 0;
+
+  const char* format_prefix;
+  switch (union_mode) {
+    case ns(UnionMode_Sparse):
+      n_chars = snprintf(format_cursor, format_out_size, "+us:");
+      format_cursor += n_chars;
+      format_out_size -= n_chars;
+      break;
+    case ns(UnionMode_Dense):
+      n_chars = snprintf(format_cursor, format_out_size, "+ud:");
+      format_cursor += n_chars;
+      format_out_size -= n_chars;
+      break;
+    default:
+      ArrowErrorSet(error, "Unexpected Union UnionMode value: %d", (int)union_mode);
+      return EINVAL;
+  }
+
+  if (ns(Union_typeIds_is_present(type))) {
+    flatbuffers_int32_vec_t type_ids = ns(Union_typeIds(type));
+    int64_t n_type_ids = flatbuffers_int32_vec_len(type_ids);
+
+    if (n_type_ids != n_children) {
+      ArrowErrorSet(
+          error,
+          "Expected between %ld children for Union type with %ld typeIds but found %ld",
+          (long)n_type_ids, (long)n_type_ids, (long)n_children);
+      return EINVAL;
+    }
+
+    if (n_type_ids > 0) {
+      n_chars = snprintf(format_cursor, format_out_size, "%d",
+                         flatbuffers_int32_vec_at(type_ids, 0));
+      format_cursor += n_chars;
+      format_out_size -= n_chars;
+
+      for (int64_t i = 1; i < n_type_ids; i++) {
+        n_chars = snprintf(format_cursor, format_out_size, ",%d",
+                           (int)flatbuffers_int32_vec_at(type_ids, i));
+        format_cursor += n_chars;
+        format_out_size -= n_chars;
+      }
+    }
+  } else if (n_children > 0) {
+    n_chars = snprintf(format_cursor, format_out_size, "0");
+    format_cursor += n_chars;
+    format_out_size -= n_chars;
+
+    for (int64_t i = 1; i < n_children; i++) {
+      n_chars = snprintf(format_cursor, format_out_size, ",%d", (int)i);
+      format_cursor += n_chars;
+      format_out_size -= n_chars;
+    }
+  }
+
+  return ArrowIpcReaderSetTypeSimpleNested(schema, union_types_str, error);
+}
+
+static int ArrowIpcReaderSetType(struct ArrowSchema* schema, ns(Field_table_t) field,
+                                 int64_t n_children, struct ArrowError* error) {
+  int type_type = ns(Field_type_type(field));
+  switch (type_type) {
+    case ns(Type_Null):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_NA, error);
+    case ns(Type_Bool):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_BOOL, error);
+    case ns(Type_Int):
+      return ArrowIpcReaderSetTypeInt(schema, ns(Field_type_get(field)), error);
+    case ns(Type_FloatingPoint):
+      return ArrowIpcReaderSetTypeFloatingPoint(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Decimal):
+      return ArrowIpcReaderSetTypeDecimal(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Binary):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_BINARY, error);
+    case ns(Type_LargeBinary):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_LARGE_BINARY, error);
+    case ns(Type_FixedSizeBinary):
+      return ArrowIpcReaderSetTypeFixedSizeBinary(schema, ns(Field_type_get(field)),
+                                                  error);
+    case ns(Type_Utf8):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_STRING, error);
+    case ns(Type_LargeUtf8):
+      return ArrowIpcReaderSetTypeSimple(schema, NANOARROW_TYPE_LARGE_STRING, error);
+    case ns(Type_Date):
+      return ArrowIpcReaderSetTypeDate(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Time):
+      return ArrowIpcReaderSetTypeTime(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Timestamp):
+      return ArrowIpcReaderSetTypeTimestamp(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Duration):
+      return ArrowIpcReaderSetTypeDuration(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Interval):
+      return ArrowIpcReaderSetTypeInterval(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Struct_):
+      return ArrowIpcReaderSetTypeSimpleNested(schema, "+s", error);
+    case ns(Type_List):
+      return ArrowIpcReaderSetTypeSimpleNested(schema, "+l", error);
+    case ns(Type_LargeList):
+      return ArrowIpcReaderSetTypeSimpleNested(schema, "+L", error);
+    case ns(Type_FixedSizeList):
+      return ArrowIpcReaderSetTypeFixedSizeList(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Map):
+      return ArrowIpcReaderSetTypeMap(schema, ns(Field_type_get(field)), error);
+    case ns(Type_Union):
+      return ArrowIpcReaderSetTypeUnion(schema, ns(Field_type_get(field)), n_children,
+                                        error);
+    default:
+      ArrowErrorSet(error, "Unrecognized Field type with value %d", (int)type_type);
+      return EINVAL;
+  }
+}
+
+static int ArrowIpcReaderSetChildren(struct ArrowSchema* schema, ns(Field_vec_t) fields,
+                                     struct ArrowError* error);
+
 static int ArrowIpcReaderSetField(struct ArrowSchema* schema, ns(Field_table_t) field,
                                   struct ArrowError* error) {
+  // No dictionary support yet
+  if (ns(Field_dictionary_is_present(field))) {
+    ArrowErrorSet(error, "Field DictionaryEncoding not supported");
+    return ENOTSUP;
+  }
+
   int result;
   if (ns(Field_name_is_present(field))) {
     result = ArrowSchemaSetName(schema, ns(Field_name_get(field)));
@@ -147,22 +564,35 @@ static int ArrowIpcReaderSetField(struct ArrowSchema* schema, ns(Field_table_t)
     return result;
   }
 
+  // Sets the schema->format and validates type-related inconsistencies
+  // that might exist in the flatbuffer
+  ns(Field_vec_t) children = ns(Field_children(field));
+  int64_t n_children = ns(Field_vec_len(children));
+
+  NANOARROW_RETURN_NOT_OK(ArrowIpcReaderSetType(schema, field, n_children, error));
+
+  // nanoarrow's type setters set the nullable flag by default, so we might
+  // have to unset it here.
   if (ns(Field_nullable_get(field))) {
     schema->flags |= ARROW_FLAG_NULLABLE;
+  } else {
+    schema->flags &= ~ARROW_FLAG_NULLABLE;
   }
 
-  int type_type = ns(Field_type_type(field));
-  switch (type_type) {
-    case ns(Type_Int):
-      NANOARROW_RETURN_NOT_OK(
-          ArrowIpcReaderSetTypeInt(schema, ns(Field_type_get(field)), error));
-      break;
-    default:
-      ArrowErrorSet(error, "Unrecognized Field type with value %d", (int)type_type);
-      return EINVAL;
+  // Children are defined separately in the flatbuffer, so we allocate, initialize
+  // and set them separately as well.
+  result = ArrowSchemaAllocateChildren(schema, n_children);
+  if (result != NANOARROW_OK) {
+    ArrowErrorSet(error, "ArrowSchemaAllocateChildren() failed");
+    return result;
   }
 
-  return NANOARROW_OK;
+  for (int64_t i = 0; i < n_children; i++) {
+    ArrowSchemaInit(schema->children[i]);
+  }
+
+  NANOARROW_RETURN_NOT_OK(ArrowIpcReaderSetChildren(schema, children, error));
+  return ArrowIpcReaderSetMetadata(schema, ns(Field_custom_metadata(field)), error);
 }
 
 static int ArrowIpcReaderSetChildren(struct ArrowSchema* schema, ns(Field_vec_t) fields,
@@ -229,7 +659,9 @@ static int ArrowIpcReaderDecodeSchema(struct ArrowIpcReader* reader,
     return result;
   }
 
-  return ArrowIpcReaderSetChildren(&reader->schema, fields, error);
+  NANOARROW_RETURN_NOT_OK(ArrowIpcReaderSetChildren(&reader->schema, fields, error));
+  return ArrowIpcReaderSetMetadata(&reader->schema, ns(Schema_custom_metadata(schema)),
+                                   error);
 }
 
 static inline int ArrowIpcReaderCheckHeader(struct ArrowIpcReader* reader,
diff --git a/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc_test.cc b/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc_test.cc
index dafee24..55835e8 100644
--- a/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc_test.cc
+++ b/extensions/nanoarrow_ipc/src/nanoarrow/nanoarrow_ipc_test.cc
@@ -19,6 +19,8 @@
 
 #include <arrow/array.h>
 #include <arrow/c/bridge.h>
+#include <arrow/ipc/api.h>
+#include <arrow/util/key_value_metadata.h>
 #include <gtest/gtest.h>
 
 #include "nanoarrow_ipc.h"
@@ -168,3 +170,132 @@ TEST(NanoarrowIpcTest, NanoarrowIpcDecodeSimpleSchema) {
 
   ArrowIpcReaderReset(&reader);
 }
+
+class ArrowTypeParameterizedTestFixture
+    : public ::testing::TestWithParam<std::shared_ptr<arrow::DataType>> {
+ protected:
+  std::shared_ptr<arrow::DataType> data_type;
+};
+
+TEST_P(ArrowTypeParameterizedTestFixture, NanoarrowIpcArrowTypeRoundtrip) {
+  const std::shared_ptr<arrow::DataType>& data_type = GetParam();
+  std::shared_ptr<arrow::Schema> dummy_schema =
+      arrow::schema({arrow::field("dummy_name", data_type)});
+  auto maybe_serialized = arrow::ipc::SerializeSchema(*dummy_schema);
+  ASSERT_TRUE(maybe_serialized.ok());
+
+  struct ArrowBufferView buffer_view;
+  buffer_view.data.data = maybe_serialized.ValueUnsafe()->data();
+  buffer_view.size_bytes = maybe_serialized.ValueOrDie()->size();
+
+  struct ArrowIpcReader reader;
+  ArrowIpcReaderInit(&reader);
+  ASSERT_EQ(ArrowIpcReaderVerify(&reader, buffer_view, nullptr), NANOARROW_OK);
+  EXPECT_EQ(reader.header_size_bytes, buffer_view.size_bytes);
+  EXPECT_EQ(reader.body_size_bytes, 0);
+
+  ASSERT_EQ(ArrowIpcReaderDecode(&reader, buffer_view, nullptr), NANOARROW_OK);
+  auto maybe_schema = arrow::ImportSchema(&reader.schema);
+  ASSERT_TRUE(maybe_schema.ok());
+
+  // Better failure message if we first check for string equality
+  EXPECT_EQ(maybe_schema.ValueUnsafe()->ToString(), dummy_schema->ToString());
+  EXPECT_TRUE(maybe_schema.ValueUnsafe()->Equals(dummy_schema, true));
+
+  ArrowIpcReaderReset(&reader);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    NanoarrowIpcTest, ArrowTypeParameterizedTestFixture,
+    ::testing::Values(
+        arrow::null(), arrow::boolean(), arrow::int8(), arrow::uint8(), arrow::int16(),
+        arrow::uint16(), arrow::int32(), arrow::uint32(), arrow::int64(), arrow::uint64(),
+        arrow::utf8(), arrow::float16(), arrow::float32(), arrow::float64(),
+        arrow::decimal128(10, 3), arrow::decimal256(10, 3), arrow::large_utf8(),
+        arrow::binary(), arrow::large_binary(), arrow::fixed_size_binary(123),
+        arrow::date32(), arrow::date64(), arrow::time32(arrow::TimeUnit::SECOND),
+        arrow::time32(arrow::TimeUnit::MILLI), arrow::time64(arrow::TimeUnit::MICRO),
+        arrow::time64(arrow::TimeUnit::NANO), arrow::timestamp(arrow::TimeUnit::SECOND),
+        arrow::timestamp(arrow::TimeUnit::MILLI),
+        arrow::timestamp(arrow::TimeUnit::MICRO), arrow::timestamp(arrow::TimeUnit::NANO),
+        arrow::timestamp(arrow::TimeUnit::SECOND, "UTC"),
+        arrow::duration(arrow::TimeUnit::SECOND), arrow::duration(arrow::TimeUnit::MILLI),
+        arrow::duration(arrow::TimeUnit::MICRO), arrow::duration(arrow::TimeUnit::NANO),
+        arrow::month_interval(), arrow::day_time_interval(),
+        arrow::month_day_nano_interval(),
+        arrow::list(arrow::field("some_custom_name", arrow::int32())),
+        arrow::large_list(arrow::field("some_custom_name", arrow::int32())),
+        arrow::fixed_size_list(arrow::field("some_custom_name", arrow::int32()), 123),
+        arrow::map(arrow::utf8(), arrow::int64(), false),
+        arrow::map(arrow::utf8(), arrow::int64(), true),
+        arrow::struct_({arrow::field("col1", arrow::int32()),
+                        arrow::field("col2", arrow::utf8())}),
+        // Zero-size union doesn't roundtrip through the C Data interface until
+        // Arrow 11 (which is not yet available on all platforms)
+        // arrow::sparse_union(FieldVector()), arrow::dense_union(FieldVector()),
+        // No custom type IDs
+        arrow::sparse_union({arrow::field("col1", arrow::int32()),
+                             arrow::field("col2", arrow::utf8())}),
+        arrow::dense_union({arrow::field("col1", arrow::int32()),
+                            arrow::field("col2", arrow::utf8())}),
+        // With custom type IDs
+        arrow::sparse_union({arrow::field("col1", arrow::int32()),
+                             arrow::field("col2", arrow::utf8())},
+                            {126, 127}),
+        arrow::dense_union({arrow::field("col1", arrow::int32()),
+                            arrow::field("col2", arrow::utf8())},
+                           {126, 127}),
+
+        // Type with nested metadata
+        arrow::list(arrow::field("some_custom_name", arrow::int32(),
+                                 arrow::KeyValueMetadata::Make({"key1"}, {"value1"})))
+
+            ));
+
+class ArrowSchemaParameterizedTestFixture
+    : public ::testing::TestWithParam<std::shared_ptr<arrow::Schema>> {
+ protected:
+  std::shared_ptr<arrow::Schema> arrow_schema;
+};
+
+TEST_P(ArrowSchemaParameterizedTestFixture, NanoarrowIpcArrowSchemaRoundtrip) {
+  const std::shared_ptr<arrow::Schema>& arrow_schema = GetParam();
+  auto maybe_serialized = arrow::ipc::SerializeSchema(*arrow_schema);
+  ASSERT_TRUE(maybe_serialized.ok());
+
+  struct ArrowBufferView buffer_view;
+  buffer_view.data.data = maybe_serialized.ValueUnsafe()->data();
+  buffer_view.size_bytes = maybe_serialized.ValueOrDie()->size();
+
+  struct ArrowIpcReader reader;
+  ArrowIpcReaderInit(&reader);
+  ASSERT_EQ(ArrowIpcReaderVerify(&reader, buffer_view, nullptr), NANOARROW_OK);
+  EXPECT_EQ(reader.header_size_bytes, buffer_view.size_bytes);
+  EXPECT_EQ(reader.body_size_bytes, 0);
+
+  ASSERT_EQ(ArrowIpcReaderDecode(&reader, buffer_view, nullptr), NANOARROW_OK);
+  auto maybe_schema = arrow::ImportSchema(&reader.schema);
+  ASSERT_TRUE(maybe_schema.ok());
+
+  // Better failure message if we first check for string equality
+  EXPECT_EQ(maybe_schema.ValueUnsafe()->ToString(), arrow_schema->ToString());
+  EXPECT_TRUE(maybe_schema.ValueUnsafe()->Equals(arrow_schema, true));
+
+  ArrowIpcReaderReset(&reader);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    NanoarrowIpcTest, ArrowSchemaParameterizedTestFixture,
+    ::testing::Values(
+        // Empty
+        arrow::schema({}),
+        // One
+        arrow::schema({arrow::field("some_name", arrow::int32())}),
+        // Field metadata
+        arrow::schema({arrow::field(
+            "some_name", arrow::int32(),
+            arrow::KeyValueMetadata::Make({"key1", "key2"}, {"value1", "value2"}))}),
+        // Schema metadata
+        arrow::schema({}, arrow::KeyValueMetadata::Make({"key1"}, {"value1"})),
+        // Non-nullable field
+        arrow::schema({arrow::field("some_name", arrow::int32(), false)})));

Reply via email to