This is an automated email from the ASF dual-hosted git repository.

raulcd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-swift.git

commit 65814342ecd2703f7f6262bf024e287c09badc60
Author: abandy <[email protected]>
AuthorDate: Thu Jul 25 15:49:52 2024 -0400

    GH-43169: [Swift] Add StructArray to ArrowReader (#43335)
    
    ### Rationale for this change
    Structs have been added to the Swift implementation, but the ArrowReader does not
    support them yet. This PR adds struct support to the ArrowReader.
    
    ### What changes are included in this PR?
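    Add StructArray support to ArrowReader: record batch nodes and buffers are now
    consumed sequentially so that struct children can be loaded recursively.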
    
    ### Are these changes tested?
    The next PR, for the ArrowWriter, will include a test for reading and writing
    structs.
    
    * GitHub Issue: #43169
    
    Authored-by: Alva Bandy <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 Arrow/Sources/Arrow/ArrowCImporter.swift    |   3 +-
 Arrow/Sources/Arrow/ArrowReader.swift       | 199 +++++++++++++++++++---------
 Arrow/Sources/Arrow/ArrowReaderHelper.swift |  59 +++++++--
 Arrow/Tests/ArrowTests/ArrayTests.swift     |   2 +-
 4 files changed, 194 insertions(+), 69 deletions(-)

diff --git a/Arrow/Sources/Arrow/ArrowCImporter.swift 
b/Arrow/Sources/Arrow/ArrowCImporter.swift
index f55077e..e65d78d 100644
--- a/Arrow/Sources/Arrow/ArrowCImporter.swift
+++ b/Arrow/Sources/Arrow/ArrowCImporter.swift
@@ -153,7 +153,8 @@ public class ArrowCImporter {
             }
         }
 
-        switch makeArrayHolder(arrowField, buffers: arrowBuffers, nullCount: 
nullCount) {
+        switch makeArrayHolder(arrowField, buffers: arrowBuffers,
+                               nullCount: nullCount, children: nil, rbLength: 
0) {
         case .success(let holder):
             return .success(ImportArrayHolder(holder, cArrayPtr: cArrayPtr))
         case .failure(let err):
diff --git a/Arrow/Sources/Arrow/ArrowReader.swift 
b/Arrow/Sources/Arrow/ArrowReader.swift
index 237f22d..ae187e2 100644
--- a/Arrow/Sources/Arrow/ArrowReader.swift
+++ b/Arrow/Sources/Arrow/ArrowReader.swift
@@ -21,14 +21,46 @@ import Foundation
 let FILEMARKER = "ARROW1"
 let CONTINUATIONMARKER = -1
 
-public class ArrowReader {
-    private struct DataLoadInfo {
+public class ArrowReader { // swiftlint:disable:this type_body_length
+    private class RecordBatchData {
+        let schema: org_apache_arrow_flatbuf_Schema
         let recordBatch: org_apache_arrow_flatbuf_RecordBatch
-        let field: org_apache_arrow_flatbuf_Field
-        let nodeIndex: Int32
-        let bufferIndex: Int32
+        private var fieldIndex: Int32 = 0
+        private var nodeIndex: Int32 = 0
+        private var bufferIndex: Int32 = 0
+        init(_ recordBatch: org_apache_arrow_flatbuf_RecordBatch,
+             schema: org_apache_arrow_flatbuf_Schema) {
+            self.recordBatch = recordBatch
+            self.schema = schema
+        }
+
+        func nextNode() -> org_apache_arrow_flatbuf_FieldNode? {
+            if nodeIndex >= self.recordBatch.nodesCount {return nil}
+            defer {nodeIndex += 1}
+            return self.recordBatch.nodes(at: nodeIndex)
+        }
+
+        func nextBuffer() -> org_apache_arrow_flatbuf_Buffer? {
+            if bufferIndex >= self.recordBatch.buffersCount {return nil}
+            defer {bufferIndex += 1}
+            return self.recordBatch.buffers(at: bufferIndex)
+        }
+
+        func nextField() -> org_apache_arrow_flatbuf_Field? {
+            if fieldIndex >= self.schema.fieldsCount {return nil}
+            defer {fieldIndex += 1}
+            return self.schema.fields(at: fieldIndex)
+        }
+
+        func isDone() -> Bool {
+            return nodeIndex >= self.recordBatch.nodesCount
+        }
+    }
+
+    private struct DataLoadInfo {
         let fileData: Data
         let messageOffset: Int64
+        var batchData: RecordBatchData
     }
 
     public class ArrowReaderResult {
@@ -54,49 +86,104 @@ public class ArrowReader {
         return .success(builder.finish())
     }
 
-    private func loadPrimitiveData(_ loadInfo: DataLoadInfo) -> 
Result<ArrowArrayHolder, ArrowError> {
-        do {
-            let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
-            let nullLength = UInt(ceil(Double(node.length) / 8))
-            try validateBufferIndex(loadInfo.recordBatch, index: 
loadInfo.bufferIndex)
-            let nullBuffer = loadInfo.recordBatch.buffers(at: 
loadInfo.bufferIndex)!
-            let arrowNullBuffer = makeBuffer(nullBuffer, fileData: 
loadInfo.fileData,
-                                             length: nullLength, 
messageOffset: loadInfo.messageOffset)
-            try validateBufferIndex(loadInfo.recordBatch, index: 
loadInfo.bufferIndex + 1)
-            let valueBuffer = loadInfo.recordBatch.buffers(at: 
loadInfo.bufferIndex + 1)!
-            let arrowValueBuffer = makeBuffer(valueBuffer, fileData: 
loadInfo.fileData,
-                                              length: UInt(node.length), 
messageOffset: loadInfo.messageOffset)
-            return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, 
arrowValueBuffer],
-                                   nullCount: UInt(node.nullCount))
-        } catch let error as ArrowError {
-            return .failure(error)
-        } catch {
-            return .failure(.unknownError("\(error)"))
+    private func loadStructData(_ loadInfo: DataLoadInfo,
+                                field: org_apache_arrow_flatbuf_Field)
+    -> Result<ArrowArrayHolder, ArrowError> {
+        guard let node = loadInfo.batchData.nextNode() else {
+            return .failure(.invalid("Node not found"))
+        }
+
+        guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
+            return .failure(.invalid("Null buffer not found"))
+        }
+
+        let nullLength = UInt(ceil(Double(node.length) / 8))
+        let arrowNullBuffer = makeBuffer(nullBuffer, fileData: 
loadInfo.fileData,
+                                         length: nullLength, messageOffset: 
loadInfo.messageOffset)
+        var children = [ArrowData]()
+        for index in 0..<field.childrenCount {
+            let childField = field.children(at: index)!
+            switch loadField(loadInfo, field: childField) {
+            case .success(let holder):
+                children.append(holder.array.arrowData)
+            case .failure(let error):
+                return .failure(error)
+            }
         }
+
+        return makeArrayHolder(field, buffers: [arrowNullBuffer],
+                               nullCount: UInt(node.nullCount), children: 
children,
+                               rbLength: 
UInt(loadInfo.batchData.recordBatch.length))
     }
 
-    private func loadVariableData(_ loadInfo: DataLoadInfo) -> 
Result<ArrowArrayHolder, ArrowError> {
-        let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
-        do {
-            let nullLength = UInt(ceil(Double(node.length) / 8))
-            try validateBufferIndex(loadInfo.recordBatch, index: 
loadInfo.bufferIndex)
-            let nullBuffer = loadInfo.recordBatch.buffers(at: 
loadInfo.bufferIndex)!
-            let arrowNullBuffer = makeBuffer(nullBuffer, fileData: 
loadInfo.fileData,
-                                             length: nullLength, 
messageOffset: loadInfo.messageOffset)
-            try validateBufferIndex(loadInfo.recordBatch, index: 
loadInfo.bufferIndex + 1)
-            let offsetBuffer = loadInfo.recordBatch.buffers(at: 
loadInfo.bufferIndex + 1)!
-            let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: 
loadInfo.fileData,
-                                               length: UInt(node.length), 
messageOffset: loadInfo.messageOffset)
-            try validateBufferIndex(loadInfo.recordBatch, index: 
loadInfo.bufferIndex + 2)
-            let valueBuffer = loadInfo.recordBatch.buffers(at: 
loadInfo.bufferIndex + 2)!
-            let arrowValueBuffer = makeBuffer(valueBuffer, fileData: 
loadInfo.fileData,
-                                              length: UInt(node.length), 
messageOffset: loadInfo.messageOffset)
-            return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, 
arrowOffsetBuffer, arrowValueBuffer],
-                                   nullCount: UInt(node.nullCount))
-        } catch let error as ArrowError {
-            return .failure(error)
-        } catch {
-            return .failure(.unknownError("\(error)"))
+    private func loadPrimitiveData(
+        _ loadInfo: DataLoadInfo,
+        field: org_apache_arrow_flatbuf_Field)
+    -> Result<ArrowArrayHolder, ArrowError> {
+        guard let node = loadInfo.batchData.nextNode() else {
+            return .failure(.invalid("Node not found"))
+        }
+
+        guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
+            return .failure(.invalid("Null buffer not found"))
+        }
+
+        guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
+            return .failure(.invalid("Value buffer not found"))
+        }
+
+        let nullLength = UInt(ceil(Double(node.length) / 8))
+        let arrowNullBuffer = makeBuffer(nullBuffer, fileData: 
loadInfo.fileData,
+                                         length: nullLength, messageOffset: 
loadInfo.messageOffset)
+        let arrowValueBuffer = makeBuffer(valueBuffer, fileData: 
loadInfo.fileData,
+                                          length: UInt(node.length), 
messageOffset: loadInfo.messageOffset)
+        return makeArrayHolder(field, buffers: [arrowNullBuffer, 
arrowValueBuffer],
+                               nullCount: UInt(node.nullCount), children: nil,
+                               rbLength: 
UInt(loadInfo.batchData.recordBatch.length))
+    }
+
+    private func loadVariableData(
+        _ loadInfo: DataLoadInfo,
+        field: org_apache_arrow_flatbuf_Field)
+    -> Result<ArrowArrayHolder, ArrowError> {
+        guard let node = loadInfo.batchData.nextNode() else {
+            return .failure(.invalid("Node not found"))
+        }
+
+        guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
+            return .failure(.invalid("Null buffer not found"))
+        }
+
+        guard let offsetBuffer = loadInfo.batchData.nextBuffer() else {
+            return .failure(.invalid("Offset buffer not found"))
+        }
+
+        guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
+            return .failure(.invalid("Value buffer not found"))
+        }
+
+        let nullLength = UInt(ceil(Double(node.length) / 8))
+        let arrowNullBuffer = makeBuffer(nullBuffer, fileData: 
loadInfo.fileData,
+                                         length: nullLength, messageOffset: 
loadInfo.messageOffset)
+        let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: 
loadInfo.fileData,
+                                           length: UInt(node.length), 
messageOffset: loadInfo.messageOffset)
+        let arrowValueBuffer = makeBuffer(valueBuffer, fileData: 
loadInfo.fileData,
+                                          length: UInt(node.length), 
messageOffset: loadInfo.messageOffset)
+        return makeArrayHolder(field, buffers: [arrowNullBuffer, 
arrowOffsetBuffer, arrowValueBuffer],
+                               nullCount: UInt(node.nullCount), children: nil,
+                               rbLength: 
UInt(loadInfo.batchData.recordBatch.length))
+    }
+
+    private func loadField(
+        _ loadInfo: DataLoadInfo,
+        field: org_apache_arrow_flatbuf_Field)
+    -> Result<ArrowArrayHolder, ArrowError> {
+        if isNestedType(field.typeType) {
+            return loadStructData(loadInfo, field: field)
+        } else if isFixedPrimitive(field.typeType) {
+            return loadPrimitiveData(loadInfo, field: field)
+        } else {
+            return loadVariableData(loadInfo, field: field)
         }
     }
 
@@ -107,23 +194,17 @@ public class ArrowReader {
         data: Data,
         messageEndOffset: Int64
     ) -> Result<RecordBatch, ArrowError> {
-        let nodesCount = recordBatch.nodesCount
-        var bufferIndex: Int32 = 0
         var columns: [ArrowArrayHolder] = []
-        for nodeIndex in 0 ..< nodesCount {
-            let field = schema.fields(at: nodeIndex)!
-            let loadInfo = DataLoadInfo(recordBatch: recordBatch, field: field,
-                                        nodeIndex: nodeIndex, bufferIndex: 
bufferIndex,
-                                        fileData: data, messageOffset: 
messageEndOffset)
-            var result: Result<ArrowArrayHolder, ArrowError>
-            if isFixedPrimitive(field.typeType) {
-                result = loadPrimitiveData(loadInfo)
-                bufferIndex += 2
-            } else {
-                result = loadVariableData(loadInfo)
-                bufferIndex += 3
+        let batchData = RecordBatchData(recordBatch, schema: schema)
+        let loadInfo = DataLoadInfo(fileData: data,
+                                    messageOffset: messageEndOffset,
+                                    batchData: batchData)
+        while !batchData.isDone() {
+            guard let field = batchData.nextField() else {
+                return .failure(.invalid("Field not found"))
             }
 
+            let result = loadField(loadInfo, field: field)
             switch result {
             case .success(let holder):
                 columns.append(holder)
diff --git a/Arrow/Sources/Arrow/ArrowReaderHelper.swift 
b/Arrow/Sources/Arrow/ArrowReaderHelper.swift
index 22c0672..48c6fd8 100644
--- a/Arrow/Sources/Arrow/ArrowReaderHelper.swift
+++ b/Arrow/Sources/Arrow/ArrowReaderHelper.swift
@@ -117,19 +117,42 @@ private func makeFixedHolder<T>(
     }
 }
 
+ func makeStructHolder(
+    _ field: ArrowField,
+    buffers: [ArrowBuffer],
+    nullCount: UInt,
+    children: [ArrowData],
+    rbLength: UInt
+) -> Result<ArrowArrayHolder, ArrowError> {
+    do {
+        let arrowData = try ArrowData(field.type,
+                                      buffers: buffers, children: children,
+                                      nullCount: nullCount, length: rbLength)
+        return .success(ArrowArrayHolderImpl(try StructArray(arrowData)))
+    } catch let error as ArrowError {
+        return .failure(error)
+    } catch {
+        return .failure(.unknownError("\(error)"))
+    }
+}
+
 func makeArrayHolder(
     _ field: org_apache_arrow_flatbuf_Field,
     buffers: [ArrowBuffer],
-    nullCount: UInt
+    nullCount: UInt,
+    children: [ArrowData]?,
+    rbLength: UInt
 ) -> Result<ArrowArrayHolder, ArrowError> {
     let arrowField = fromProto(field: field)
-    return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount)
+    return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount, 
children: children, rbLength: rbLength)
 }
 
 func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
     _ field: ArrowField,
     buffers: [ArrowBuffer],
-    nullCount: UInt
+    nullCount: UInt,
+    children: [ArrowData]?,
+    rbLength: UInt
 ) -> Result<ArrowArrayHolder, ArrowError> {
     let typeId = field.type.id
     switch typeId {
@@ -159,12 +182,12 @@ func makeArrayHolder( // swiftlint:disable:this 
cyclomatic_complexity
         return makeStringHolder(buffers, nullCount: nullCount)
     case .binary:
         return makeBinaryHolder(buffers, nullCount: nullCount)
-    case .date32:
+    case .date32, .date64:
         return makeDateHolder(field, buffers: buffers, nullCount: nullCount)
-    case .time32:
-        return makeTimeHolder(field, buffers: buffers, nullCount: nullCount)
-    case .time64:
+    case .time32, .time64:
         return makeTimeHolder(field, buffers: buffers, nullCount: nullCount)
+    case .strct:
+        return makeStructHolder(field, buffers: buffers, nullCount: nullCount, 
children: children!, rbLength: rbLength)
     default:
         return .failure(.unknownType("Type \(typeId) currently not supported"))
     }
@@ -187,7 +210,16 @@ func isFixedPrimitive(_ type: 
org_apache_arrow_flatbuf_Type_) -> Bool {
     }
 }
 
-func findArrowType( // swiftlint:disable:this cyclomatic_complexity
+func isNestedType(_ type: org_apache_arrow_flatbuf_Type_) -> Bool {
+    switch type {
+    case .struct_:
+        return true
+    default:
+        return false
+    }
+}
+
+func findArrowType( // swiftlint:disable:this cyclomatic_complexity 
function_body_length
     _ field: org_apache_arrow_flatbuf_Field) -> ArrowType {
     let type = field.typeType
     switch type {
@@ -229,6 +261,17 @@ func findArrowType( // swiftlint:disable:this 
cyclomatic_complexity
         }
 
         return ArrowTypeTime64(timeType.unit == .microsecond ? .microseconds : 
.nanoseconds)
+    case .struct_:
+        _ = field.type(type: org_apache_arrow_flatbuf_Struct_.self)!
+        var fields = [ArrowField]()
+        for index in 0..<field.childrenCount {
+            let childField = field.children(at: index)!
+            let childType = findArrowType(childField)
+            fields.append(
+                ArrowField(childField.name ?? "", type: childType, isNullable: 
childField.nullable))
+        }
+
+        return ArrowNestedType(ArrowType.ArrowStruct, fields: fields)
     default:
         return ArrowType(ArrowType.ArrowUnknown)
     }
diff --git a/Arrow/Tests/ArrowTests/ArrayTests.swift 
b/Arrow/Tests/ArrowTests/ArrayTests.swift
index bfd7492..d793aa1 100644
--- a/Arrow/Tests/ArrowTests/ArrayTests.swift
+++ b/Arrow/Tests/ArrowTests/ArrayTests.swift
@@ -279,7 +279,7 @@ final class ArrayTests: XCTestCase { // 
swiftlint:disable:this type_body_length
                        ArrowBuffer(length: 0, capacity: 0,
                                rawPointer: 
UnsafeMutableRawPointer.allocate(byteCount: 0, alignment: .zero))]
         let field = ArrowField("", type: checkType, isNullable: true)
-        switch makeArrayHolder(field, buffers: buffers, nullCount: 0) {
+        switch makeArrayHolder(field, buffers: buffers, nullCount: 0, 
children: nil, rbLength: 0) {
         case .success(let holder):
             XCTAssertEqual(holder.type.id, checkType.id)
         case .failure(let err):

Reply via email to