This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d28078d80e GH-42245: [Swift] Ensure map behavior is the same for all key types (#42246)
d28078d80e is described below

commit d28078d80e7fe7ff93e4b8a9a331ec4ba4648bf1
Author: abandy <[email protected]>
AuthorDate: Fri Jun 21 21:03:01 2024 -0400

    GH-42245: [Swift] Ensure map behavior is the same for all key types (#42246)
    
    ### Rationale for this change
    Map decoding currently behaves differently depending on the key type (String keys vs. non-String keys).
    
    ### What changes are included in this PR?
    Added a dedicated map-decoding method so that all maps are decoded the same way, regardless of key type.
    
    ### Are these changes tested?
    Yes
    
    * GitHub Issue: #42245
    
    Authored-by: Alva Bandy <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 swift/Arrow/Sources/Arrow/ArrowDecoder.swift    | 58 +++++++++++++++----------
 swift/Arrow/Tests/ArrowTests/CodableTests.swift | 43 ++++++++++--------
 2 files changed, 61 insertions(+), 40 deletions(-)

diff --git a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift
index 518b4e9c32..7e684f360a 100644
--- a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift
+++ b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift
@@ -19,6 +19,7 @@ import Foundation
 
 public class ArrowDecoder: Decoder {
     var rbIndex: UInt = 0
+    var singleRBCol: Int = 0
     public var codingPath: [CodingKey] = []
     public var userInfo: [CodingUserInfoKey: Any] = [:]
     public let rb: RecordBatch
@@ -47,6 +48,25 @@ public class ArrowDecoder: Decoder {
         self.nameToCol = colMapping
     }
 
+    public func decode<T: Decodable, U: Decodable>(_ type: [T: U].Type) throws 
-> [T: U] {
+        var output = [T: U]()
+        if rb.columnCount != 2 {
+            throw ArrowError.invalid("RecordBatch column count of 2 is 
required to decode to map")
+        }
+
+        for index in 0..<rb.length {
+            self.rbIndex = index
+            self.singleRBCol = 0
+            let key = try T.init(from: self)
+            self.singleRBCol = 1
+            let value = try U.init(from: self)
+            output[key] = value
+        }
+
+        self.singleRBCol = 0
+        return output
+    }
+
     public func decode<T: Decodable>(_ type: T.Type) throws -> [T] {
         var output = [T]()
         for index in 0..<rb.length {
@@ -252,7 +272,7 @@ private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtoco
     }
 
     func decode<T>(_ type: T.Type, forKey key: Key) throws -> T where T: 
Decodable {
-        if type == Date.self {
+        if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
             return try self.decoder.doDecode(key)!
         } else {
             throw ArrowError.invalid("Type \(type) is currently not supported")
@@ -290,26 +310,26 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
 
     func decodeNil() -> Bool {
         do {
-            return try self.decoder.isNull(0)
+            return try self.decoder.isNull(self.decoder.singleRBCol)
         } catch {
             return false
         }
     }
 
     func decode(_ type: Bool.Type) throws -> Bool {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: String.Type) throws -> String {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: Double.Type) throws -> Double {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: Float.Type) throws -> Float {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: Int.Type) throws -> Int {
@@ -318,19 +338,19 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
     }
 
     func decode(_ type: Int8.Type) throws -> Int8 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: Int16.Type) throws -> Int16 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: Int32.Type) throws -> Int32 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: Int64.Type) throws -> Int64 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: UInt.Type) throws -> UInt {
@@ -339,30 +359,24 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
     }
 
     func decode(_ type: UInt8.Type) throws -> UInt8 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: UInt16.Type) throws -> UInt16 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: UInt32.Type) throws -> UInt32 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode(_ type: UInt64.Type) throws -> UInt64 {
-        return try self.decoder.doDecode(0)!
+        return try self.decoder.doDecode(self.decoder.singleRBCol)!
     }
 
     func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
-        if type == Int8.self || type == Int16.self ||
-            type == Int32.self || type == Int64.self ||
-            type == UInt8.self || type == UInt16.self ||
-            type == UInt32.self || type == UInt64.self ||
-            type == String.self || type == Double.self ||
-            type == Float.self || type == Date.self ||
-            type == Bool.self {
-            return try self.decoder.doDecode(0)!
+        if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
+            return try self.decoder.doDecode(self.decoder.singleRBCol)!
         } else {
             throw ArrowError.invalid("Type \(type) is currently not supported")
         }
diff --git a/swift/Arrow/Tests/ArrowTests/CodableTests.swift b/swift/Arrow/Tests/ArrowTests/CodableTests.swift
index 160beea17c..400faa9f29 100644
--- a/swift/Arrow/Tests/ArrowTests/CodableTests.swift
+++ b/swift/Arrow/Tests/ArrowTests/CodableTests.swift
@@ -166,35 +166,45 @@ final class CodableTests: XCTestCase {
         }
     }
 
-    func testArrowUnkeyedDecoderWithoutNull() throws {
+    func testArrowMapDecoderWithoutNull() throws {
         let int8Builder: NumberArrayBuilder<Int8> = try 
ArrowArrayBuilders.loadNumberArrayBuilder()
         let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
         int8Builder.append(10, 11, 12, 13)
-        stringBuilder.append("test0", "test1", "test2", "test3")
-        let result = RecordBatch.Builder()
+        stringBuilder.append("test10", "test11", "test12", "test13")
+        switch RecordBatch.Builder()
             .addColumn("propInt8", arrowArray: try int8Builder.toHolder())
             .addColumn("propString", arrowArray: try stringBuilder.toHolder())
-            .finish()
-        switch result {
+            .finish() {
         case .success(let rb):
             let decoder = ArrowDecoder(rb)
             let testData = try decoder.decode([Int8: String].self)
-            var index: Int8 = 0
             for data in testData {
-                let str = data[10 + index]
-                XCTAssertEqual(str, "test\(index)")
-                index += 1
+                XCTAssertEqual("test\(data.key)", data.value)
+            }
+        case .failure(let err):
+            throw err
+        }
+
+        switch RecordBatch.Builder()
+            .addColumn("propString", arrowArray: try stringBuilder.toHolder())
+            .addColumn("propInt8", arrowArray: try int8Builder.toHolder())
+            .finish() {
+        case .success(let rb):
+            let decoder = ArrowDecoder(rb)
+            let testData = try decoder.decode([String: Int8].self)
+            for data in testData {
+                XCTAssertEqual("test\(data.value)", data.key)
             }
         case .failure(let err):
             throw err
         }
     }
 
-    func testArrowUnkeyedDecoderWithNull() throws {
+    func testArrowMapDecoderWithNull() throws {
         let int8Builder: NumberArrayBuilder<Int8> = try 
ArrowArrayBuilders.loadNumberArrayBuilder()
         let stringWNilBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
         int8Builder.append(10, 11, 12, 13)
-        stringWNilBuilder.append(nil, "test1", nil, "test3")
+        stringWNilBuilder.append(nil, "test11", nil, "test13")
         let resultWNil = RecordBatch.Builder()
             .addColumn("propInt8", arrowArray: try int8Builder.toHolder())
             .addColumn("propString", arrowArray: try 
stringWNilBuilder.toHolder())
@@ -203,19 +213,16 @@ final class CodableTests: XCTestCase {
         case .success(let rb):
             let decoder = ArrowDecoder(rb)
             let testData = try decoder.decode([Int8: String?].self)
-            var index: Int8 = 0
             for data in testData {
-                let str = data[10 + index]
-                if index % 2 == 0 {
-                    XCTAssertNil(str!)
+                let str = data.value
+                if data.key % 2 == 0 {
+                    XCTAssertNil(str)
                 } else {
-                    XCTAssertEqual(str, "test\(index)")
+                    XCTAssertEqual(str, "test\(data.key)")
                 }
-                index += 1
             }
         case .failure(let err):
             throw err
         }
-
     }
 }

Reply via email to