This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 601be7687b GH-42020: [Swift] Add Arrow decoding implementation for
Swift Codable (#42023)
601be7687b is described below
commit 601be7687ba89f711b876397746b5f49503c0871
Author: abandy <[email protected]>
AuthorDate: Sat Jun 8 17:25:28 2024 -0400
GH-42020: [Swift] Add Arrow decoding implementation for Swift Codable
(#42023)
### Rationale for this change
This change implements decode for the Arrow Swift Codable implementation.
This allows the data in a RecordBatch to be copied to properties in a
struct/class.
The PR is a bit longer than desired, but all three container types are
required to implement the Decoder protocol.
### What changes are included in this PR?
The ArrowDecoder class is included in this PR along with a class for each
container type (keyed, unkeyed, and single). Most of the logic is encapsulated
in the ArrowDecoder, with minimal logic in each container class (most of the
methods in the container classes are a single line that calls one of the
ArrowDecoder doDecode methods).
### Are these changes tested?
Yes, tests have been added that exercise the three container types provided
by the decoder.
* GitHub Issue: #42020
Authored-by: Alva Bandy <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
swift/Arrow/Sources/Arrow/ArrowDecoder.swift | 347 ++++++++++++++++++++++++
swift/Arrow/Tests/ArrowTests/CodableTests.swift | 170 ++++++++++++
2 files changed, 517 insertions(+)
diff --git a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift
b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift
new file mode 100644
index 0000000000..7e0c69b1e7
--- /dev/null
+++ b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift
@@ -0,0 +1,347 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import Foundation
+
+ /// Decodes the rows of a `RecordBatch` into `Decodable` Swift values.
+ ///
+ /// Every row is decoded as one value. Keyed containers look columns up by
+ /// field name, unkeyed containers walk the columns in schema order, and
+ /// single-value containers read the first column.
+ public class ArrowDecoder: Decoder {
+     /// Row of `rb` that the containers currently read from.
+     var rbIndex: UInt = 0
+     public var codingPath: [CodingKey] = []
+     public var userInfo: [CodingUserInfoKey: Any] = [:]
+     public let rb: RecordBatch
+     /// Columns of `rb` keyed by field name.
+     public let nameToCol: [String: ArrowArrayHolder]
+     /// Columns of `rb` in schema order.
+     public let columns: [ArrowArrayHolder]
+
+     /// Creates a decoder sharing the batch, column lookups, coding path,
+     /// user info and current row of `decoder`.
+     public init(_ decoder: ArrowDecoder) {
+         self.userInfo = decoder.userInfo
+         self.codingPath = decoder.codingPath
+         self.rb = decoder.rb
+         self.columns = decoder.columns
+         self.nameToCol = decoder.nameToCol
+         self.rbIndex = decoder.rbIndex
+     }
+
+     /// Creates a decoder over `rb`, positioned at the first row.
+     public init(_ rb: RecordBatch) {
+         self.rb = rb
+         var colMapping = [String: ArrowArrayHolder]()
+         var columns = [ArrowArrayHolder]()
+         for (index, field) in rb.schema.fields.enumerated() {
+             // Fetch each column once and share it between both lookups.
+             let column = rb.column(index)
+             columns.append(column)
+             colMapping[field.name] = column
+         }
+
+         self.columns = columns
+         self.nameToCol = colMapping
+     }
+
+     /// Decodes every row of the batch as a `T`.
+     /// - Parameter type: The type to decode each row into.
+     /// - Returns: One decoded value per row, in row order.
+     /// - Throws: Any error raised by `T.init(from:)` or by the containers.
+     public func decode<T: Decodable>(_ type: T.Type) throws -> [T] {
+         var output = [T]()
+         output.reserveCapacity(Int(rb.length))
+         for index in 0..<rb.length {
+             self.rbIndex = index
+             output.append(try type.init(from: self))
+         }
+
+         return output
+     }
+
+     public func container<Key>(keyedBy type: Key.Type
+     ) -> KeyedDecodingContainer<Key> where Key: CodingKey {
+         let container = ArrowKeyedDecoding<Key>(self, codingPath: codingPath)
+         return KeyedDecodingContainer(container)
+     }
+
+     public func unkeyedContainer() -> UnkeyedDecodingContainer {
+         return ArrowUnkeyedDecoding(self, codingPath: codingPath)
+     }
+
+     public func singleValueContainer() -> SingleValueDecodingContainer {
+         return ArrowSingleValueDecoding(self, codingPath: codingPath)
+     }
+
+     /// Returns the column named `name` as an `AnyArray`.
+     /// - Throws: `ArrowError.invalid` if no such column exists or it cannot
+     ///   be viewed as an `AnyArray`.
+     func getCol(_ name: String) throws -> AnyArray {
+         guard let col = self.nameToCol[name] else {
+             throw ArrowError.invalid("Column for key \"\(name)\" not found")
+         }
+
+         guard let anyArray = col.array as? AnyArray else {
+             throw ArrowError.invalid("Unable to convert array to AnyArray")
+         }
+
+         return anyArray
+     }
+
+     /// Returns the column at position `index` (schema order) as an `AnyArray`.
+     /// - Throws: `ArrowError.outOfBounds` if `index` is negative or past the
+     ///   last column; `ArrowError.invalid` if it cannot be viewed as an
+     ///   `AnyArray`.
+     func getCol(_ index: Int) throws -> AnyArray {
+         // Reject negative indices too; they would otherwise trap in the
+         // array subscript below instead of throwing.
+         guard self.columns.indices.contains(index) else {
+             throw ArrowError.outOfBounds(index: Int64(index))
+         }
+
+         guard let anyArray = self.columns[index].array as? AnyArray else {
+             throw ArrowError.invalid("Unable to convert array to AnyArray")
+         }
+
+         return anyArray
+     }
+
+     /// Reads the current row's value from the column named by `key`.
+     /// - Returns: The value cast to `T`, or `nil` when the value is null or
+     ///   is not a `T`.
+     func doDecode<T>(_ key: CodingKey) throws -> T? {
+         let array: AnyArray = try self.getCol(key.stringValue)
+         return array.asAny(self.rbIndex) as? T
+     }
+
+     /// Reads the current row's value from the column at position `col`.
+     /// - Returns: The value cast to `T`, or `nil` when the value is null or
+     ///   is not a `T`.
+     func doDecode<T>(_ col: Int) throws -> T? {
+         let array: AnyArray = try self.getCol(col)
+         return array.asAny(self.rbIndex) as? T
+     }
+ }
+
+ /// Unkeyed container that reads the columns of the decoder's current row in
+ /// schema order, one column per decoded element.
+ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {
+     var codingPath: [CodingKey]
+     /// Number of elements available: one per column of the batch.
+     var count: Int? = 0
+     /// Index of the next column to decode.
+     var currentIndex: Int = 0
+     let decoder: ArrowDecoder
+
+     /// True once every column has been consumed. Computed so that a batch
+     /// with no columns is reported as already at the end.
+     var isAtEnd: Bool {
+         return currentIndex >= (count ?? 0)
+     }
+
+     init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) {
+         self.decoder = decoder
+         self.codingPath = codingPath
+         self.count = self.decoder.columns.count
+     }
+
+     mutating func increment() {
+         self.currentIndex += 1
+     }
+
+     /// Returns whether the current column is null for this row.
+     ///
+     /// Per the `UnkeyedDecodingContainer` contract, the index only advances
+     /// when the value is null, so `decodeIfPresent` can still decode a
+     /// non-null value afterwards.
+     mutating func decodeNil() throws -> Bool {
+         let isNull = try decoder.getCol(currentIndex).asAny(decoder.rbIndex) == nil
+         if isNull {
+             increment()
+         }
+         return isNull
+     }
+
+     /// Decodes the current column as `type` and advances to the next column.
+     /// - Throws: `ArrowError.invalid` for unsupported types, or when the value
+     ///   is null or not of the requested type (the index is not advanced).
+     mutating func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
+         let supportedTypes: [Any.Type] = [
+             Bool.self,
+             Int8.self, Int16.self, Int32.self, Int64.self,
+             UInt8.self, UInt16.self, UInt32.self, UInt64.self,
+             String.self, Double.self, Float.self, Date.self
+         ]
+         guard supportedTypes.contains(where: { $0 == type }) else {
+             throw ArrowError.invalid("Type \(type) is currently not supported")
+         }
+
+         // Throw instead of trapping when the value is null or mismatched.
+         guard let value: T = try decoder.doDecode(currentIndex) else {
+             throw ArrowError.invalid(
+                 "Column \(currentIndex) is null or not of type \(type) at row \(decoder.rbIndex)")
+         }
+
+         increment()
+         return value
+     }
+
+     func nestedContainer<NestedKey>(
+         keyedBy type: NestedKey.Type
+     ) throws -> KeyedDecodingContainer<NestedKey> where NestedKey: CodingKey {
+         throw ArrowError.invalid("Nested decoding is currently not supported.")
+     }
+
+     func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer {
+         throw ArrowError.invalid("Nested decoding is currently not supported.")
+     }
+
+     func superDecoder() throws -> Decoder {
+         throw ArrowError.invalid("super decoding is currently not supported.")
+     }
+ }
+
+ /// Keyed container that maps each coding key to the column of the same name
+ /// and reads that column's value for the decoder's current row.
+ private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtocol {
+     var codingPath = [CodingKey]()
+     /// Keys for every column whose name converts to `Key`, in schema order.
+     var allKeys = [Key]()
+     let decoder: ArrowDecoder
+
+     init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) {
+         self.decoder = decoder
+         self.codingPath = codingPath
+         self.allKeys = decoder.rb.schema.fields.compactMap { Key(stringValue: $0.name) }
+     }
+
+     /// Reads the current row's non-null value for `key` as a `T`.
+     /// - Throws: `ArrowError.invalid` if the column is missing, or if the
+     ///   value is null or not of type `T` (rather than trapping on `!`).
+     private func decodeValue<T>(_ type: T.Type, forKey key: Key) throws -> T {
+         guard let value: T = try decoder.doDecode(key) else {
+             throw ArrowError.invalid(
+                 "Value for key \"\(key.stringValue)\" is null or not of type \(type) at row \(decoder.rbIndex)")
+         }
+         return value
+     }
+
+     func contains(_ key: Key) -> Bool {
+         return self.decoder.nameToCol.keys.contains(key.stringValue)
+     }
+
+     /// Returns whether the column for `key` is null in the current row.
+     func decodeNil(forKey key: Key) throws -> Bool {
+         return try decoder.getCol(key.stringValue).asAny(decoder.rbIndex) == nil
+     }
+
+     func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: String.Type, forKey key: Key) throws -> String {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: Double.Type, forKey key: Key) throws -> Double {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: Float.Type, forKey key: Key) throws -> Float {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: Int.Type, forKey key: Key) throws -> Int {
+         throw ArrowError.invalid(
+             "Int type is not supported (please use Int8, Int16, Int32 or Int64)")
+     }
+
+     func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt {
+         throw ArrowError.invalid(
+             "UInt type is not supported (please use UInt8, UInt16, UInt32 or UInt64)")
+     }
+
+     func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 {
+         return try decodeValue(type, forKey: key)
+     }
+
+     /// Generic fallback; only `Date` is currently supported.
+     func decode<T>(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable {
+         guard type == Date.self else {
+             throw ArrowError.invalid("Type \(type) is currently not supported")
+         }
+         return try decodeValue(type, forKey: key)
+     }
+
+     func nestedContainer<NestedKey>(
+         keyedBy type: NestedKey.Type,
+         forKey key: Key
+     ) throws -> KeyedDecodingContainer<NestedKey> where NestedKey: CodingKey {
+         throw ArrowError.invalid("Nested decoding is currently not supported.")
+     }
+
+     func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer {
+         throw ArrowError.invalid("Nested decoding is currently not supported.")
+     }
+
+     func superDecoder() throws -> Decoder {
+         throw ArrowError.invalid("super decoding is currently not supported.")
+     }
+
+     func superDecoder(forKey key: Key) throws -> Decoder {
+         throw ArrowError.invalid("super decoding is currently not supported.")
+     }
+ }
+
+ /// Single-value container that reads the first column (index 0) of the
+ /// decoder's current row.
+ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
+     var codingPath = [CodingKey]()
+     let decoder: ArrowDecoder
+
+     init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) {
+         self.decoder = decoder
+         self.codingPath = codingPath
+     }
+
+     /// Reads the current row's non-null value from column 0 as a `T`.
+     /// - Throws: `ArrowError` if column 0 is missing, or `ArrowError.invalid`
+     ///   if the value is null or not of type `T` (rather than trapping on `!`).
+     private func decodeValue<T>(_ type: T.Type) throws -> T {
+         guard let value: T = try decoder.doDecode(0) else {
+             throw ArrowError.invalid(
+                 "Column 0 is null or not of type \(type) at row \(decoder.rbIndex)")
+         }
+         return value
+     }
+
+     /// Returns whether column 0 is null in the current row.
+     func decodeNil() -> Bool {
+         do {
+             return try decoder.getCol(0).asAny(decoder.rbIndex) == nil
+         } catch {
+             // The protocol method cannot throw; report "not nil" so the
+             // following decode call surfaces the error instead.
+             return false
+         }
+     }
+
+     func decode(_ type: Bool.Type) throws -> Bool {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: String.Type) throws -> String {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: Double.Type) throws -> Double {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: Float.Type) throws -> Float {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: Int.Type) throws -> Int {
+         throw ArrowError.invalid(
+             "Int type is not supported (please use Int8, Int16, Int32 or Int64)")
+     }
+
+     func decode(_ type: Int8.Type) throws -> Int8 {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: Int16.Type) throws -> Int16 {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: Int32.Type) throws -> Int32 {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: Int64.Type) throws -> Int64 {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: UInt.Type) throws -> UInt {
+         throw ArrowError.invalid(
+             "UInt type is not supported (please use UInt8, UInt16, UInt32 or UInt64)")
+     }
+
+     func decode(_ type: UInt8.Type) throws -> UInt8 {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: UInt16.Type) throws -> UInt16 {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: UInt32.Type) throws -> UInt32 {
+         return try decodeValue(type)
+     }
+
+     func decode(_ type: UInt64.Type) throws -> UInt64 {
+         return try decodeValue(type)
+     }
+
+     /// Generic fallback; only `Date` is currently supported.
+     func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
+         guard type == Date.self else {
+             throw ArrowError.invalid("Type \(type) is currently not supported")
+         }
+         return try decodeValue(type)
+     }
+ }
diff --git a/swift/Arrow/Tests/ArrowTests/CodableTests.swift
b/swift/Arrow/Tests/ArrowTests/CodableTests.swift
new file mode 100644
index 0000000000..e7359467ae
--- /dev/null
+++ b/swift/Arrow/Tests/ArrowTests/CodableTests.swift
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import XCTest
+@testable import Arrow
+
+ /// Exercises `ArrowDecoder` through each of its three container kinds.
+ final class CodableTests: XCTestCase {
+     /// Decodable target whose property names match the keyed test's columns.
+     public class TestClass: Codable {
+         public var propBool: Bool
+         public var propInt8: Int8
+         public var propInt16: Int16
+         public var propInt32: Int32
+         public var propInt64: Int64
+         public var propUInt8: UInt8
+         public var propUInt16: UInt16
+         public var propUInt32: UInt32
+         public var propUInt64: UInt64
+         public var propFloat: Float
+         public var propDouble: Double
+         public var propString: String
+         public var propDate: Date
+
+         public required init() {
+             self.propBool = false
+             self.propInt8 = 1
+             self.propInt16 = 2
+             self.propInt32 = 3
+             self.propInt64 = 4
+             self.propUInt8 = 5
+             self.propUInt16 = 6
+             self.propUInt32 = 7
+             self.propUInt64 = 8
+             self.propFloat = 9
+             self.propDouble = 10
+             self.propString = "11"
+             self.propDate = Date.now
+         }
+     }
+
+     /// Keyed container: each column maps onto the property of the same name.
+     func testArrowKeyedDecoder() throws { // swiftlint:disable:this function_body_length
+         let date1 = Date(timeIntervalSinceReferenceDate: 86400 * 5000 + 352)
+
+         let boolBuilder = try ArrowArrayBuilders.loadBoolArrayBuilder()
+         let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let int16Builder: NumberArrayBuilder<Int16> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let int32Builder: NumberArrayBuilder<Int32> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let int64Builder: NumberArrayBuilder<Int64> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let uint8Builder: NumberArrayBuilder<UInt8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let uint16Builder: NumberArrayBuilder<UInt16> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let uint32Builder: NumberArrayBuilder<UInt32> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let uint64Builder: NumberArrayBuilder<UInt64> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let floatBuilder: NumberArrayBuilder<Float> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let doubleBuilder: NumberArrayBuilder<Double> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
+         let dateBuilder = try ArrowArrayBuilders.loadDate64ArrayBuilder()
+
+         boolBuilder.append(false, true, false)
+         int8Builder.append(10, 11, 12)
+         int16Builder.append(20, 21, 22)
+         int32Builder.append(30, 31, 32)
+         int64Builder.append(40, 41, 42)
+         uint8Builder.append(50, 51, 52)
+         uint16Builder.append(60, 61, 62)
+         uint32Builder.append(70, 71, 72)
+         uint64Builder.append(80, 81, 82)
+         floatBuilder.append(90.1, 91.1, 92.1)
+         doubleBuilder.append(100.1, 101.1, 102.1)
+         stringBuilder.append("test0", "test1", "test2")
+         dateBuilder.append(date1, date1, date1)
+         let result = RecordBatch.Builder()
+             .addColumn("propBool", arrowArray: try boolBuilder.toHolder())
+             .addColumn("propInt8", arrowArray: try int8Builder.toHolder())
+             .addColumn("propInt16", arrowArray: try int16Builder.toHolder())
+             .addColumn("propInt32", arrowArray: try int32Builder.toHolder())
+             .addColumn("propInt64", arrowArray: try int64Builder.toHolder())
+             .addColumn("propUInt8", arrowArray: try uint8Builder.toHolder())
+             .addColumn("propUInt16", arrowArray: try uint16Builder.toHolder())
+             .addColumn("propUInt32", arrowArray: try uint32Builder.toHolder())
+             .addColumn("propUInt64", arrowArray: try uint64Builder.toHolder())
+             .addColumn("propFloat", arrowArray: try floatBuilder.toHolder())
+             .addColumn("propDouble", arrowArray: try doubleBuilder.toHolder())
+             .addColumn("propString", arrowArray: try stringBuilder.toHolder())
+             .addColumn("propDate", arrowArray: try dateBuilder.toHolder())
+             .finish()
+         switch result {
+         case .success(let rb):
+             let decoder = ArrowDecoder(rb)
+             let testClasses = try decoder.decode(TestClass.self)
+             XCTAssertEqual(testClasses.count, 3)
+             for (index, testClass) in testClasses.enumerated() {
+                 XCTAssertEqual(testClass.propBool, index % 2 != 0)
+                 XCTAssertEqual(testClass.propInt8, Int8(index + 10))
+                 XCTAssertEqual(testClass.propInt16, Int16(index + 20))
+                 XCTAssertEqual(testClass.propInt32, Int32(index + 30))
+                 XCTAssertEqual(testClass.propInt64, Int64(index + 40))
+                 XCTAssertEqual(testClass.propUInt8, UInt8(index + 50))
+                 XCTAssertEqual(testClass.propUInt16, UInt16(index + 60))
+                 XCTAssertEqual(testClass.propUInt32, UInt32(index + 70))
+                 XCTAssertEqual(testClass.propUInt64, UInt64(index + 80))
+                 XCTAssertEqual(testClass.propFloat, Float(index) + 90.1)
+                 XCTAssertEqual(testClass.propDouble, Double(index) + 100.1)
+                 XCTAssertEqual(testClass.propString, "test\(index)")
+                 XCTAssertEqual(testClass.propDate, date1)
+             }
+         case .failure(let err):
+             throw err
+         }
+     }
+
+     /// Single-value container: decodes column 0, including a trailing null.
+     func testArrowSingleDecoder() throws {
+         let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         int8Builder.append(10, 11, 12, nil)
+         let result = RecordBatch.Builder()
+             .addColumn("propInt8", arrowArray: try int8Builder.toHolder())
+             .finish()
+         switch result {
+         case .success(let rb):
+             let decoder = ArrowDecoder(rb)
+             let testData = try decoder.decode(Int8?.self)
+             XCTAssertEqual(testData.count, 4)
+             for (index, val) in testData.enumerated() {
+                 if index == 3 {
+                     // The appended nil must round-trip as nil, not be skipped.
+                     XCTAssertNil(val)
+                 } else {
+                     XCTAssertEqual(val, Int8(index + 10))
+                 }
+             }
+         case .failure(let err):
+             throw err
+         }
+     }
+
+     /// Unkeyed container: each row of (Int8, String) decodes as a one-entry
+     /// dictionary.
+     func testArrowUnkeyedDecoder() throws {
+         let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
+         let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
+         int8Builder.append(10, 11, 12)
+         stringBuilder.append("test0", "test1", "test2")
+         let result = RecordBatch.Builder()
+             .addColumn("propInt8", arrowArray: try int8Builder.toHolder())
+             .addColumn("propString", arrowArray: try stringBuilder.toHolder())
+             .finish()
+         switch result {
+         case .success(let rb):
+             let decoder = ArrowDecoder(rb)
+             let testData = try decoder.decode([Int8: String].self)
+             XCTAssertEqual(testData.count, 3)
+             var index: Int8 = 0
+             for data in testData {
+                 let str = data[10 + index]
+                 XCTAssertEqual(str, "test\(index)")
+                 index += 1
+             }
+         case .failure(let err):
+             throw err
+         }
+     }
+
+ }