This is an automated email from the ASF dual-hosted git repository.

raulcd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-swift.git

commit eda41f856c83e481f45cb55707e057e4e56e647c
Author: abandy <[email protected]>
AuthorDate: Tue May 27 21:04:41 2025 -0400

    GH-43170: [Swift] Add StructArray support to ArrowWriter (#43439)
    
    ### Rationale for this change
    Support for the Struct type has already been added to the Swift ArrowReader.
    This change adds the same support to the ArrowWriter.
    
    ### What changes are included in this PR?
    Updates to the ArrowWriter to support the Struct type.
    
    ### Are these changes tested?
    Yes, tests are included in this PR.
    
    * GitHub Issue: #43170
    
    Authored-by: Alva Bandy <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 Arrow/Sources/Arrow/ArrowWriter.swift       | 107 ++++++++++----
 Arrow/Sources/Arrow/ArrowWriterHelper.swift |  61 ++++----
 Arrow/Sources/Arrow/ProtoUtil.swift         |  10 +-
 Arrow/Tests/ArrowTests/IPCTests.swift       | 214 +++++++++++++++++++++++-----
 Arrow/Tests/ArrowTests/TableTests.swift     |  49 +++++++
 5 files changed, 347 insertions(+), 94 deletions(-)

diff --git a/Arrow/Sources/Arrow/ArrowWriter.swift 
b/Arrow/Sources/Arrow/ArrowWriter.swift
index 54581ba..3aa25b6 100644
--- a/Arrow/Sources/Arrow/ArrowWriter.swift
+++ b/Arrow/Sources/Arrow/ArrowWriter.swift
@@ -71,11 +71,30 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
     public init() {}
 
     private func writeField(_ fbb: inout FlatBufferBuilder, field: ArrowField) 
-> Result<Offset, ArrowError> {
+        var fieldsOffset: Offset?
+        if let nestedField = field.type as? ArrowNestedType {
+            var offsets = [Offset]()
+            for field in nestedField.fields {
+                switch writeField(&fbb, field: field) {
+                case .success(let offset):
+                    offsets.append(offset)
+                case .failure(let error):
+                    return .failure(error)
+                }
+            }
+
+            fieldsOffset = fbb.createVector(ofOffsets: offsets)
+        }
+
         let nameOffset = fbb.create(string: field.name)
         let fieldTypeOffsetResult = toFBType(&fbb, arrowType: field.type)
         let startOffset = org_apache_arrow_flatbuf_Field.startField(&fbb)
         org_apache_arrow_flatbuf_Field.add(name: nameOffset, &fbb)
         org_apache_arrow_flatbuf_Field.add(nullable: field.isNullable, &fbb)
+        if let childrenOffset = fieldsOffset {
+            org_apache_arrow_flatbuf_Field.addVectorOf(children: 
childrenOffset, &fbb)
+        }
+
         switch toFBTypeEnum(field.type) {
         case .success(let type):
             org_apache_arrow_flatbuf_Field.add(typeType: type, &fbb)
@@ -101,7 +120,6 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
             case .failure(let error):
                 return .failure(error)
             }
-
         }
 
         let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets)
@@ -126,7 +144,7 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
                 withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) 
{writer.append(Data($0))}
                 withUnsafeBytes(of: rbResult.1.o.littleEndian) 
{writer.append(Data($0))}
                 writer.append(rbResult.0)
-                switch writeRecordBatchData(&writer, batch: batch) {
+                switch writeRecordBatchData(&writer, fields: 
batch.schema.fields, columns: batch.columns) {
                 case .success:
                     rbBlocks.append(
                         org_apache_arrow_flatbuf_Block(offset: 
Int64(startIndex),
@@ -143,37 +161,59 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
         return .success(rbBlocks)
     }
 
-    private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, 
Offset), ArrowError> {
-        let schema = batch.schema
-        var fbb = FlatBufferBuilder()
-
-        // write out field nodes
-        var fieldNodeOffsets = [Offset]()
-        fbb.startVector(schema.fields.count, elementSize: 
MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
-        for index in (0 ..< schema.fields.count).reversed() {
-            let column = batch.column(index)
+    private func writeFieldNodes(_ fields: [ArrowField], columns: 
[ArrowArrayHolder], offsets: inout [Offset],
+                                 fbb: inout FlatBufferBuilder) {
+        for index in (0 ..< fields.count).reversed() {
+            let column = columns[index]
             let fieldNode =
                 org_apache_arrow_flatbuf_FieldNode(length: 
Int64(column.length),
                                                    nullCount: 
Int64(column.nullCount))
-            fieldNodeOffsets.append(fbb.create(struct: fieldNode))
+            offsets.append(fbb.create(struct: fieldNode))
+            if let nestedType = column.type as? ArrowNestedType {
+                let structArray = column.array as? StructArray
+                writeFieldNodes(nestedType.fields, columns: 
structArray!.arrowFields!, offsets: &offsets, fbb: &fbb)
+            }
         }
+    }
 
-        let nodeOffset = fbb.endVector(len: schema.fields.count)
-
-        // write out buffers
-        var buffers = [org_apache_arrow_flatbuf_Buffer]()
-        var bufferOffset = Int(0)
-        for index in 0 ..< batch.schema.fields.count {
-            let column = batch.column(index)
+    private func writeBufferInfo(_ fields: [ArrowField],
+                                 columns: [ArrowArrayHolder],
+                                 bufferOffset: inout Int,
+                                 buffers: inout 
[org_apache_arrow_flatbuf_Buffer],
+                                 fbb: inout FlatBufferBuilder) {
+        for index in 0 ..< fields.count {
+            let column = columns[index]
             let colBufferDataSizes = column.getBufferDataSizes()
             for var bufferDataSize in colBufferDataSizes {
                 bufferDataSize = getPadForAlignment(bufferDataSize)
                 let buffer = org_apache_arrow_flatbuf_Buffer(offset: 
Int64(bufferOffset), length: Int64(bufferDataSize))
                 buffers.append(buffer)
                 bufferOffset += bufferDataSize
+                if let nestedType = column.type as? ArrowNestedType {
+                    let structArray = column.array as? StructArray
+                    writeBufferInfo(nestedType.fields, columns: 
structArray!.arrowFields!,
+                                    bufferOffset: &bufferOffset, buffers: 
&buffers, fbb: &fbb)
+                }
             }
         }
+    }
 
+    private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, 
Offset), ArrowError> {
+        let schema = batch.schema
+        var fbb = FlatBufferBuilder()
+
+        // write out field nodes
+        var fieldNodeOffsets = [Offset]()
+        fbb.startVector(schema.fields.count, elementSize: 
MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
+        writeFieldNodes(schema.fields, columns: batch.columns, offsets: 
&fieldNodeOffsets, fbb: &fbb)
+        let nodeOffset = fbb.endVector(len: fieldNodeOffsets.count)
+
+        // write out buffers
+        var buffers = [org_apache_arrow_flatbuf_Buffer]()
+        var bufferOffset = Int(0)
+        writeBufferInfo(schema.fields, columns: batch.columns,
+                        bufferOffset: &bufferOffset, buffers: &buffers,
+                        fbb: &fbb)
         
org_apache_arrow_flatbuf_RecordBatch.startVectorOfBuffers(batch.schema.fields.count,
 in: &fbb)
         for buffer in buffers.reversed() {
             fbb.create(struct: buffer)
@@ -196,13 +236,28 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
         return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
     }
 
-    private func writeRecordBatchData(_ writer: inout DataWriter, batch: 
RecordBatch) -> Result<Bool, ArrowError> {
-        for index in 0 ..< batch.schema.fields.count {
-            let column = batch.column(index)
+    private func writeRecordBatchData(
+        _ writer: inout DataWriter, fields: [ArrowField],
+        columns: [ArrowArrayHolder])
+    -> Result<Bool, ArrowError> {
+        for index in 0 ..< fields.count {
+            let column = columns[index]
             let colBufferData = column.getBufferData()
             for var bufferData in colBufferData {
                 addPadForAlignment(&bufferData)
                 writer.append(bufferData)
+                if let nestedType = column.type as? ArrowNestedType {
+                    guard let structArray = column.array as? StructArray else {
+                        return .failure(.invalid("Struct type array expected 
for nested type"))
+                    }
+
+                    switch writeRecordBatchData(&writer, fields: 
nestedType.fields, columns: structArray.arrowFields!) {
+                    case .success:
+                        continue
+                    case .failure(let error):
+                        return .failure(error)
+                    }
+                }
             }
         }
 
@@ -226,11 +281,10 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
             org_apache_arrow_flatbuf_Footer.addVectorOf(recordBatches: 
rbBlkEnd, &fbb)
             let footerOffset = org_apache_arrow_flatbuf_Footer.endFooter(&fbb, 
start: footerStartOffset)
             fbb.finish(offset: footerOffset)
+            return .success(fbb.data)
         case .failure(let error):
             return .failure(error)
         }
-
-        return .success(fbb.data)
     }
 
     private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) 
-> Result<Bool, ArrowError> {
@@ -265,7 +319,7 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
         return .success(true)
     }
 
-    public func writeSteaming(_ info: ArrowWriter.Info) -> Result<Data, 
ArrowError> {
+    public func writeStreaming(_ info: ArrowWriter.Info) -> Result<Data, 
ArrowError> {
         let writer: any DataWriter = InMemDataWriter()
         switch toMessage(info.schema) {
         case .success(let schemaData):
@@ -343,7 +397,7 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
             writer.append(message.0)
             addPadForAlignment(&writer)
             var dataWriter: any DataWriter = InMemDataWriter()
-            switch writeRecordBatchData(&dataWriter, batch: batch) {
+            switch writeRecordBatchData(&dataWriter, fields: 
batch.schema.fields, columns: batch.columns) {
             case .success:
                 return .success([
                     (writer as! InMemDataWriter).data, // 
swiftlint:disable:this force_cast
@@ -377,3 +431,4 @@ public class ArrowWriter { // swiftlint:disable:this 
type_body_length
         return .success(fbb.data)
     }
 }
+// swiftlint:disable:this file_length
diff --git a/Arrow/Sources/Arrow/ArrowWriterHelper.swift 
b/Arrow/Sources/Arrow/ArrowWriterHelper.swift
index fdc72ef..4d63192 100644
--- a/Arrow/Sources/Arrow/ArrowWriterHelper.swift
+++ b/Arrow/Sources/Arrow/ArrowWriterHelper.swift
@@ -25,67 +25,69 @@ extension Data {
 }
 
 func toFBTypeEnum(_ arrowType: ArrowType) -> 
Result<org_apache_arrow_flatbuf_Type_, ArrowError> {
-    let infoType = arrowType.info
-    if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowInt16 ||
-        infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt8 ||
-        infoType == ArrowType.ArrowUInt16 || infoType == ArrowType.ArrowUInt32 
||
-        infoType == ArrowType.ArrowUInt64 || infoType == ArrowType.ArrowInt32 {
+    let typeId = arrowType.id
+    switch typeId {
+    case .int8, .int16, .int32, .int64, .uint8, .uint16, .uint32, .uint64:
         return .success(org_apache_arrow_flatbuf_Type_.int)
-    } else if infoType == ArrowType.ArrowFloat || infoType == 
ArrowType.ArrowDouble {
+    case .float, .double:
         return .success(org_apache_arrow_flatbuf_Type_.floatingpoint)
-    } else if infoType == ArrowType.ArrowString {
+    case .string:
         return .success(org_apache_arrow_flatbuf_Type_.utf8)
-    } else if infoType == ArrowType.ArrowBinary {
+    case .binary:
         return .success(org_apache_arrow_flatbuf_Type_.binary)
-    } else if infoType == ArrowType.ArrowBool {
+    case .boolean:
         return .success(org_apache_arrow_flatbuf_Type_.bool)
-    } else if infoType == ArrowType.ArrowDate32 || infoType == 
ArrowType.ArrowDate64 {
+    case .date32, .date64:
         return .success(org_apache_arrow_flatbuf_Type_.date)
-    } else if infoType == ArrowType.ArrowTime32 || infoType == 
ArrowType.ArrowTime64 {
+    case .time32, .time64:
         return .success(org_apache_arrow_flatbuf_Type_.time)
+    case .strct:
+        return .success(org_apache_arrow_flatbuf_Type_.struct_)
+    default:
+        return .failure(.unknownType("Unable to find flatbuf type for Arrow 
type: \(typeId)"))
     }
-    return .failure(.unknownType("Unable to find flatbuf type for Arrow type: 
\(infoType)"))
 }
 
-func toFBType( // swiftlint:disable:this cyclomatic_complexity
+func toFBType( // swiftlint:disable:this cyclomatic_complexity 
function_body_length
     _ fbb: inout FlatBufferBuilder,
     arrowType: ArrowType
 ) -> Result<Offset, ArrowError> {
     let infoType = arrowType.info
-    if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowUInt8 {
+    switch arrowType.id {
+    case .int8, .uint8:
         return .success(org_apache_arrow_flatbuf_Int.createInt(
             &fbb, bitWidth: 8, isSigned: infoType == ArrowType.ArrowInt8))
-    } else if infoType == ArrowType.ArrowInt16 || infoType == 
ArrowType.ArrowUInt16 {
+    case .int16, .uint16:
         return .success(org_apache_arrow_flatbuf_Int.createInt(
             &fbb, bitWidth: 16, isSigned: infoType == ArrowType.ArrowInt16))
-    } else if infoType == ArrowType.ArrowInt32 || infoType == 
ArrowType.ArrowUInt32 {
+    case .int32, .uint32:
         return .success(org_apache_arrow_flatbuf_Int.createInt(
             &fbb, bitWidth: 32, isSigned: infoType == ArrowType.ArrowInt32))
-    } else if infoType == ArrowType.ArrowInt64 || infoType == 
ArrowType.ArrowUInt64 {
+    case .int64, .uint64:
         return .success(org_apache_arrow_flatbuf_Int.createInt(
             &fbb, bitWidth: 64, isSigned: infoType == ArrowType.ArrowInt64))
-    } else if infoType == ArrowType.ArrowFloat {
+    case .float:
         return 
.success(org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, 
precision: .single))
-    } else if infoType == ArrowType.ArrowDouble {
+    case .double:
         return 
.success(org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, 
precision: .double))
-    } else if infoType == ArrowType.ArrowString {
+    case .string:
         return .success(org_apache_arrow_flatbuf_Utf8.endUtf8(
             &fbb, start: org_apache_arrow_flatbuf_Utf8.startUtf8(&fbb)))
-    } else if infoType == ArrowType.ArrowBinary {
+    case .binary:
         return .success(org_apache_arrow_flatbuf_Binary.endBinary(
             &fbb, start: org_apache_arrow_flatbuf_Binary.startBinary(&fbb)))
-    } else if infoType == ArrowType.ArrowBool {
+    case .boolean:
         return .success(org_apache_arrow_flatbuf_Bool.endBool(
             &fbb, start: org_apache_arrow_flatbuf_Bool.startBool(&fbb)))
-    } else if infoType == ArrowType.ArrowDate32 {
+    case .date32:
         let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
         org_apache_arrow_flatbuf_Date.add(unit: .day, &fbb)
         return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: 
startOffset))
-    } else if infoType == ArrowType.ArrowDate64 {
+    case .date64:
         let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
         org_apache_arrow_flatbuf_Date.add(unit: .millisecond, &fbb)
         return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: 
startOffset))
-    } else if infoType == ArrowType.ArrowTime32 {
+    case .time32:
         let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
         if let timeType = arrowType as? ArrowTypeTime32 {
             org_apache_arrow_flatbuf_Time.add(unit: timeType.unit == .seconds 
? .second : .millisecond, &fbb)
@@ -93,7 +95,7 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity
         }
 
         return .failure(.invalid("Unable to case to Time32"))
-    } else if infoType == ArrowType.ArrowTime64 {
+    case .time64:
         let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
         if let timeType = arrowType as? ArrowTypeTime64 {
             org_apache_arrow_flatbuf_Time.add(unit: timeType.unit == 
.microseconds ? .microsecond : .nanosecond, &fbb)
@@ -101,9 +103,12 @@ func toFBType( // swiftlint:disable:this 
cyclomatic_complexity
         }
 
         return .failure(.invalid("Unable to case to Time64"))
+    case .strct:
+        let startOffset = org_apache_arrow_flatbuf_Struct_.startStruct_(&fbb)
+        return .success(org_apache_arrow_flatbuf_Struct_.endStruct_(&fbb, 
start: startOffset))
+    default:
+        return .failure(.unknownType("Unable to add flatbuf type for Arrow 
type: \(infoType)"))
     }
-
-    return .failure(.unknownType("Unable to add flatbuf type for Arrow type: 
\(infoType)"))
 }
 
 func addPadForAlignment(_ data: inout Data, alignment: Int = 8) {
diff --git a/Arrow/Sources/Arrow/ProtoUtil.swift 
b/Arrow/Sources/Arrow/ProtoUtil.swift
index ac61030..88cfb0b 100644
--- a/Arrow/Sources/Arrow/ProtoUtil.swift
+++ b/Arrow/Sources/Arrow/ProtoUtil.swift
@@ -17,7 +17,7 @@
 
 import Foundation
 
-func fromProto( // swiftlint:disable:this cyclomatic_complexity
+func fromProto( // swiftlint:disable:this cyclomatic_complexity 
function_body_length
     field: org_apache_arrow_flatbuf_Field
 ) -> ArrowField {
     let type = field.typeType
@@ -65,7 +65,13 @@ func fromProto( // swiftlint:disable:this 
cyclomatic_complexity
             arrowType = ArrowTypeTime64(arrowUnit)
         }
     case .struct_:
-        arrowType = ArrowType(ArrowType.ArrowStruct)
+        var children = [ArrowField]()
+        for index in 0..<field.childrenCount {
+            let childField = field.children(at: index)!
+            children.append(fromProto(field: childField))
+        }
+
+        arrowType = ArrowNestedType(ArrowType.ArrowStruct, fields: children)
     default:
         arrowType = ArrowType(ArrowType.ArrowUnknown)
     }
diff --git a/Arrow/Tests/ArrowTests/IPCTests.swift 
b/Arrow/Tests/ArrowTests/IPCTests.swift
index 703490d..26f38ce 100644
--- a/Arrow/Tests/ArrowTests/IPCTests.swift
+++ b/Arrow/Tests/ArrowTests/IPCTests.swift
@@ -19,6 +19,24 @@ import XCTest
 import FlatBuffers
 @testable import Arrow
 
+let currentDate = Date.now
+class StructTest {
+    var field0: Bool = false
+    var field1: Int8 = 0
+    var field2: Int16 = 0
+    var field: Int32 = 0
+    var field4: Int64 = 0
+    var field5: UInt8 = 0
+    var field6: UInt16 = 0
+    var field7: UInt32 = 0
+    var field8: UInt64 = 0
+    var field9: Double = 0
+    var field10: Float = 0
+    var field11: String = ""
+    var field12 = Data()
+    var field13: Date = currentDate
+}
+
 @discardableResult
 func checkBoolRecordBatch(_ result: Result<ArrowReader.ArrowReaderResult, 
ArrowError>) throws -> [RecordBatch] {
     let recordBatches: [RecordBatch]
@@ -55,6 +73,37 @@ func checkBoolRecordBatch(_ result: 
Result<ArrowReader.ArrowReaderResult, ArrowE
     return recordBatches
 }
 
+@discardableResult
+func checkStructRecordBatch(_ result: Result<ArrowReader.ArrowReaderResult, 
ArrowError>) throws -> [RecordBatch] {
+    let recordBatches: [RecordBatch]
+    switch result {
+    case .success(let result):
+        recordBatches = result.batches
+    case .failure(let error):
+        throw error
+    }
+
+    XCTAssertEqual(recordBatches.count, 1)
+    for recordBatch in recordBatches {
+        XCTAssertEqual(recordBatch.length, 3)
+        XCTAssertEqual(recordBatch.columns.count, 1)
+        XCTAssertEqual(recordBatch.schema.fields.count, 1)
+        XCTAssertEqual(recordBatch.schema.fields[0].name, "my struct")
+        XCTAssertEqual(recordBatch.schema.fields[0].type.id, .strct)
+        let structArray = recordBatch.columns[0].array as? StructArray
+        XCTAssertEqual(structArray!.arrowFields!.count, 2)
+        XCTAssertEqual(structArray!.arrowFields![0].type.id, .string)
+        XCTAssertEqual(structArray!.arrowFields![1].type.id, .boolean)
+        let column = recordBatch.columns[0]
+        let str = column.array as? AsString
+        XCTAssertEqual("\(str!.asString(0))", "{0,false}")
+        XCTAssertEqual("\(str!.asString(1))", "{1,true}")
+        XCTAssertTrue(column.array.asAny(2) == nil)
+    }
+
+    return recordBatches
+}
+
 func currentDirectory(path: String = #file) -> URL {
     return URL(fileURLWithPath: path).deletingLastPathComponent()
 }
@@ -69,6 +118,47 @@ func makeSchema() -> ArrowSchema {
         .finish()
 }
 
+func makeStructSchema() -> ArrowSchema {
+    let testObj = StructTest()
+    var fields = [ArrowField]()
+    let buildStructType = {() -> ArrowNestedType in
+        let mirror = Mirror(reflecting: testObj)
+        for (property, value) in mirror.children {
+            let arrowType = ArrowType(ArrowType.infoForType(type(of: value)))
+            fields.append(ArrowField(property!, type: arrowType, isNullable: 
true))
+        }
+
+        return ArrowNestedType(ArrowType.ArrowStruct, fields: fields)
+    }
+
+    return ArrowSchema.Builder()
+        .addField("struct1", type: buildStructType(), isNullable: true)
+        .finish()
+}
+
+func makeStructRecordBatch() throws -> RecordBatch {
+    let testData = StructTest()
+    let dateNow = Date.now
+    let structBuilder = try 
ArrowArrayBuilders.loadStructArrayBuilderForType(testData)
+    structBuilder.append([true, Int8(1), Int16(2), Int32(3), Int64(4),
+                          UInt8(5), UInt16(6), UInt32(7), UInt64(8), 
Double(9.9),
+                          Float(10.10), "11", Data("12".utf8), dateNow])
+    structBuilder.append(nil)
+    structBuilder.append([true, Int8(13), Int16(14), Int32(15), Int64(16),
+                          UInt8(17), UInt16(18), UInt32(19), UInt64(20), 
Double(21.21),
+                          Float(22.22), "23", Data("24".utf8), dateNow])
+    let structHolder = ArrowArrayHolderImpl(try structBuilder.finish())
+    let result = RecordBatch.Builder()
+        .addColumn("struct1", arrowArray: structHolder)
+        .finish()
+    switch result {
+    case .success(let recordBatch):
+        return recordBatch
+    case .failure(let error):
+        throw error
+    }
+}
+
 func makeRecordBatch() throws -> RecordBatch {
     let uint8Builder: NumberArrayBuilder<UInt8> = try 
ArrowArrayBuilders.loadNumberArrayBuilder()
     uint8Builder.append(10)
@@ -124,7 +214,7 @@ final class IPCStreamReaderTests: XCTestCase {
         let recordBatch = try makeRecordBatch()
         let arrowWriter = ArrowWriter()
         let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, 
batches: [recordBatch])
-        switch arrowWriter.writeSteaming(writerInfo) {
+        switch arrowWriter.writeStreaming(writerInfo) {
         case .success(let writeData):
             let arrowReader = ArrowReader()
             switch arrowReader.readStreaming(writeData) {
@@ -173,43 +263,6 @@ final class IPCStreamReaderTests: XCTestCase {
 }
 
 final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this 
type_body_length
-    func testFileReader_struct() throws {
-        let fileURL = 
currentDirectory().appendingPathComponent("../../testdata_struct.arrow")
-        let arrowReader = ArrowReader()
-        let result = arrowReader.fromFile(fileURL)
-        let recordBatches: [RecordBatch]
-        switch result {
-        case .success(let result):
-            recordBatches = result.batches
-        case .failure(let error):
-            throw error
-        }
-
-        XCTAssertEqual(recordBatches.count, 1)
-        for recordBatch in recordBatches {
-            XCTAssertEqual(recordBatch.length, 3)
-            XCTAssertEqual(recordBatch.columns.count, 1)
-            XCTAssertEqual(recordBatch.schema.fields.count, 1)
-            XCTAssertEqual(recordBatch.schema.fields[0].type.info, 
ArrowType.ArrowStruct)
-            let column = recordBatch.columns[0]
-            XCTAssertNotNil(column.array as? StructArray)
-            if let structArray = column.array as? StructArray {
-                XCTAssertEqual(structArray.arrowFields?.count, 2)
-                XCTAssertEqual(structArray.arrowFields?[0].type.info, 
ArrowType.ArrowString)
-                XCTAssertEqual(structArray.arrowFields?[1].type.info, 
ArrowType.ArrowBool)
-                for index in 0..<structArray.length {
-                    if index == 2 {
-                        XCTAssertNil(structArray[index])
-                    } else {
-                        XCTAssertEqual(structArray[index]?[0] as? String, 
"\(index)")
-                        XCTAssertEqual(structArray[index]?[1] as? Bool, index 
% 2 == 1)
-                    }
-                }
-            }
-
-        }
-    }
-
     func testFileReader_double() throws {
         let fileURL = 
currentDirectory().appendingPathComponent("../../testdata_double.arrow")
         let arrowReader = ArrowReader()
@@ -275,6 +328,37 @@ final class IPCFileReaderTests: XCTestCase { // 
swiftlint:disable:this type_body
         }
     }
 
+    func testFileReader_struct() throws {
+        let fileURL = 
currentDirectory().appendingPathComponent("../../testdata_struct.arrow")
+        let arrowReader = ArrowReader()
+        try checkStructRecordBatch(arrowReader.fromFile(fileURL))
+    }
+
+    func testFileWriter_struct() throws {
+        // read existing file
+        let fileURL = 
currentDirectory().appendingPathComponent("../../testdata_struct.arrow")
+        let arrowReader = ArrowReader()
+        let fileRBs = try checkStructRecordBatch(arrowReader.fromFile(fileURL))
+        let arrowWriter = ArrowWriter()
+        // write data from file to a stream
+        let writerInfo = ArrowWriter.Info(.recordbatch, schema: 
fileRBs[0].schema, batches: fileRBs)
+        switch arrowWriter.writeFile(writerInfo) {
+        case .success(let writeData):
+            // read stream back into recordbatches
+            try checkStructRecordBatch(arrowReader.readFile(writeData))
+        case .failure(let error):
+            throw error
+        }
+        // write file record batches to another file
+        let outputUrl = 
currentDirectory().appendingPathComponent("../../testfilewriter_struct.arrow")
+        switch arrowWriter.toFile(outputUrl, info: writerInfo) {
+        case .success:
+            try checkStructRecordBatch(arrowReader.fromFile(outputUrl))
+        case .failure(let error):
+            throw error
+        }
+    }
+
     func testRBInMemoryToFromStream() throws {
         // read existing file
         let schema = makeSchema()
@@ -412,6 +496,60 @@ final class IPCFileReaderTests: XCTestCase { // 
swiftlint:disable:this type_body
         }
     }
 
+    func testStructRBInMemoryToFromStream() throws {
+        // read existing file
+        let schema = makeStructSchema()
+        let recordBatch = try makeStructRecordBatch()
+        let arrowWriter = ArrowWriter()
+        let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, 
batches: [recordBatch])
+        switch arrowWriter.writeStreaming(writerInfo) {
+        case .success(let writeData):
+            let arrowReader = ArrowReader()
+            switch arrowReader.readStreaming(writeData) {
+            case .success(let result):
+                let recordBatches = result.batches
+                XCTAssertEqual(recordBatches.count, 1)
+                for recordBatch in recordBatches {
+                    XCTAssertEqual(recordBatch.length, 3)
+                    XCTAssertEqual(recordBatch.columns.count, 1)
+                    XCTAssertEqual(recordBatch.schema.fields.count, 1)
+                    XCTAssertEqual(recordBatch.schema.fields[0].name, 
"struct1")
+                    XCTAssertEqual(recordBatch.schema.fields[0].type.id, 
.strct)
+                    XCTAssertTrue(recordBatch.schema.fields[0].type is 
ArrowNestedType)
+                    let nestedType = (recordBatch.schema.fields[0].type as? 
ArrowNestedType)!
+                    XCTAssertEqual(nestedType.fields.count, 14)
+                    let columns = recordBatch.columns
+                    XCTAssertEqual(columns[0].nullCount, 1)
+                    XCTAssertNil(columns[0].array.asAny(1))
+                    let structVal =
+                        "\((columns[0].array as? AsString)!.asString(0))"
+                    XCTAssertEqual(structVal, 
"{true,1,2,3,4,5,6,7,8,9.9,10.1,11,12,\(currentDate)}")
+                    let structArray = (recordBatch.columns[0].array as? 
StructArray)!
+                    XCTAssertEqual(structArray.length, 3)
+                    XCTAssertEqual(structArray.arrowFields!.count, 14)
+                    XCTAssertEqual(structArray.arrowFields![0].type.id, 
.boolean)
+                    XCTAssertEqual(structArray.arrowFields![1].type.id, .int8)
+                    XCTAssertEqual(structArray.arrowFields![2].type.id, .int16)
+                    XCTAssertEqual(structArray.arrowFields![3].type.id, .int32)
+                    XCTAssertEqual(structArray.arrowFields![4].type.id, .int64)
+                    XCTAssertEqual(structArray.arrowFields![5].type.id, .uint8)
+                    XCTAssertEqual(structArray.arrowFields![6].type.id, 
.uint16)
+                    XCTAssertEqual(structArray.arrowFields![7].type.id, 
.uint32)
+                    XCTAssertEqual(structArray.arrowFields![8].type.id, 
.uint64)
+                    XCTAssertEqual(structArray.arrowFields![9].type.id, 
.double)
+                    XCTAssertEqual(structArray.arrowFields![10].type.id, 
.float)
+                    XCTAssertEqual(structArray.arrowFields![11].type.id, 
.string)
+                    XCTAssertEqual(structArray.arrowFields![12].type.id, 
.binary)
+                    XCTAssertEqual(structArray.arrowFields![13].type.id, 
.date64)
+                }
+            case.failure(let error):
+                throw error
+            }
+        case .failure(let error):
+            throw error
+        }
+    }
+
     func testBinaryInMemoryToFromStream() throws {
         let dataset = try makeBinaryDataset()
         let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, 
batches: [dataset.1])
diff --git a/Arrow/Tests/ArrowTests/TableTests.swift 
b/Arrow/Tests/ArrowTests/TableTests.swift
index 8e958cc..dc5cabc 100644
--- a/Arrow/Tests/ArrowTests/TableTests.swift
+++ b/Arrow/Tests/ArrowTests/TableTests.swift
@@ -33,6 +33,55 @@ final class TableTests: XCTestCase {
         XCTAssertEqual(schema.fields[1].isNullable, false)
     }
 
+    func testSchemaNested() {
+        class StructTest {
+            var field0: Bool = false
+            var field1: Int8 = 0
+            var field2: Int16 = 0
+            var field3: Int32 = 0
+            var field4: Int64 = 0
+            var field5: UInt8 = 0
+            var field6: UInt16 = 0
+            var field7: UInt32 = 0
+            var field8: UInt64 = 0
+            var field9: Double = 0
+            var field10: Float = 0
+            var field11: String = ""
+            var field12 = Data()
+            var field13: Date = Date.now
+        }
+
+        let testObj = StructTest()
+        var fields = [ArrowField]()
+        let buildStructType = {() -> ArrowNestedType in
+            let mirror = Mirror(reflecting: testObj)
+            for (property, value) in mirror.children {
+                let arrowType = ArrowType(ArrowType.infoForType(type(of: 
value)))
+                fields.append(ArrowField(property!, type: arrowType, 
isNullable: true))
+            }
+
+            return ArrowNestedType(ArrowType.ArrowStruct, fields: fields)
+        }
+
+        let structType = buildStructType()
+        XCTAssertEqual(structType.id, ArrowTypeId.strct)
+        XCTAssertEqual(structType.fields.count, 14)
+        XCTAssertEqual(structType.fields[0].type.id, ArrowTypeId.boolean)
+        XCTAssertEqual(structType.fields[1].type.id, ArrowTypeId.int8)
+        XCTAssertEqual(structType.fields[2].type.id, ArrowTypeId.int16)
+        XCTAssertEqual(structType.fields[3].type.id, ArrowTypeId.int32)
+        XCTAssertEqual(structType.fields[4].type.id, ArrowTypeId.int64)
+        XCTAssertEqual(structType.fields[5].type.id, ArrowTypeId.uint8)
+        XCTAssertEqual(structType.fields[6].type.id, ArrowTypeId.uint16)
+        XCTAssertEqual(structType.fields[7].type.id, ArrowTypeId.uint32)
+        XCTAssertEqual(structType.fields[8].type.id, ArrowTypeId.uint64)
+        XCTAssertEqual(structType.fields[9].type.id, ArrowTypeId.double)
+        XCTAssertEqual(structType.fields[10].type.id, ArrowTypeId.float)
+        XCTAssertEqual(structType.fields[11].type.id, ArrowTypeId.string)
+        XCTAssertEqual(structType.fields[12].type.id, ArrowTypeId.binary)
+        XCTAssertEqual(structType.fields[13].type.id, ArrowTypeId.date64)
+    }
+
     func testTable() throws {
         let doubleBuilder: NumberArrayBuilder<Double> = try 
ArrowArrayBuilders.loadNumberArrayBuilder()
         doubleBuilder.append(11.11)

Reply via email to