This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new e632364655 GH-39519: [Swift] Fix null count when using reader (#39520)
e632364655 is described below
commit e6323646558ee01234ce58af273c5a834745f298
Author: abandy <[email protected]>
AuthorDate: Sat Jan 13 17:02:06 2024 -0500
GH-39519: [Swift] Fix null count when using reader (#39520)
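Currently the reader does not set the null count correctly when building
an array from a stream: it passes the length of the validity (null) bitmap
buffer as the array's null count, and it sizes that bitmap buffer from the
node's null count instead of the node's length. This PR fixes both: the
bitmap buffer length is now ceil(node.length / 8) bytes, and the node's
null count is passed through to each array holder.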
* Closes: #39519
Authored-by: Alva Bandy <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
swift/Arrow/Sources/Arrow/ArrowReader.swift | 12 ++--
swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift | 82 +++++++++++++---------
swift/Arrow/Tests/ArrowTests/IPCTests.swift | 40 +++++++++--
.../Arrow/Tests/ArrowTests/RecordBatchTests.swift | 9 ++-
4 files changed, 96 insertions(+), 47 deletions(-)
diff --git a/swift/Arrow/Sources/Arrow/ArrowReader.swift
b/swift/Arrow/Sources/Arrow/ArrowReader.swift
index d9dc1bdb47..237f22dc97 100644
--- a/swift/Arrow/Sources/Arrow/ArrowReader.swift
+++ b/swift/Arrow/Sources/Arrow/ArrowReader.swift
@@ -57,15 +57,17 @@ public class ArrowReader {
private func loadPrimitiveData(_ loadInfo: DataLoadInfo) ->
Result<ArrowArrayHolder, ArrowError> {
do {
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
+ let nullLength = UInt(ceil(Double(node.length) / 8))
try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex)
let nullBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex)!
let arrowNullBuffer = makeBuffer(nullBuffer, fileData:
loadInfo.fileData,
- length: UInt(node.nullCount),
messageOffset: loadInfo.messageOffset)
+ length: nullLength,
messageOffset: loadInfo.messageOffset)
try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex + 1)
let valueBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex + 1)!
let arrowValueBuffer = makeBuffer(valueBuffer, fileData:
loadInfo.fileData,
length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
- return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer,
arrowValueBuffer])
+ return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer,
arrowValueBuffer],
+ nullCount: UInt(node.nullCount))
} catch let error as ArrowError {
return .failure(error)
} catch {
@@ -76,10 +78,11 @@ public class ArrowReader {
private func loadVariableData(_ loadInfo: DataLoadInfo) ->
Result<ArrowArrayHolder, ArrowError> {
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
do {
+ let nullLength = UInt(ceil(Double(node.length) / 8))
try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex)
let nullBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex)!
let arrowNullBuffer = makeBuffer(nullBuffer, fileData:
loadInfo.fileData,
- length: UInt(node.nullCount),
messageOffset: loadInfo.messageOffset)
+ length: nullLength,
messageOffset: loadInfo.messageOffset)
try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex + 1)
let offsetBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex + 1)!
let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData:
loadInfo.fileData,
@@ -88,7 +91,8 @@ public class ArrowReader {
let valueBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex + 2)!
let arrowValueBuffer = makeBuffer(valueBuffer, fileData:
loadInfo.fileData,
length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
- return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer,
arrowOffsetBuffer, arrowValueBuffer])
+ return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer,
arrowOffsetBuffer, arrowValueBuffer],
+ nullCount: UInt(node.nullCount))
} catch let error as ArrowError {
return .failure(error)
} catch {
diff --git a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
index fa52160478..7b3ec04b3a 100644
--- a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
+++ b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
@@ -18,10 +18,11 @@
import FlatBuffers
import Foundation
-private func makeBinaryHolder(_ buffers: [ArrowBuffer]) ->
Result<ArrowArrayHolder, ArrowError> {
+private func makeBinaryHolder(_ buffers: [ArrowBuffer],
+ nullCount: UInt) -> Result<ArrowArrayHolder,
ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBinary),
buffers: buffers,
- nullCount: buffers[0].length, stride:
MemoryLayout<Int8>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<Int8>.stride)
return .success(ArrowArrayHolder(BinaryArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
@@ -30,10 +31,11 @@ private func makeBinaryHolder(_ buffers: [ArrowBuffer]) ->
Result<ArrowArrayHold
}
}
-private func makeStringHolder(_ buffers: [ArrowBuffer]) ->
Result<ArrowArrayHolder, ArrowError> {
+private func makeStringHolder(_ buffers: [ArrowBuffer],
+ nullCount: UInt) -> Result<ArrowArrayHolder,
ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString),
buffers: buffers,
- nullCount: buffers[0].length, stride:
MemoryLayout<Int8>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<Int8>.stride)
return .success(ArrowArrayHolder(StringArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
@@ -43,30 +45,32 @@ private func makeStringHolder(_ buffers: [ArrowBuffer]) ->
Result<ArrowArrayHold
}
private func makeFloatHolder(_ floatType:
org_apache_arrow_flatbuf_FloatingPoint,
- buffers: [ArrowBuffer]
+ buffers: [ArrowBuffer],
+ nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
switch floatType.precision {
case .single:
- return makeFixedHolder(Float.self, buffers: buffers, arrowType:
ArrowType.ArrowFloat)
+ return makeFixedHolder(Float.self, buffers: buffers, arrowType:
ArrowType.ArrowFloat, nullCount: nullCount)
case .double:
- return makeFixedHolder(Double.self, buffers: buffers, arrowType:
ArrowType.ArrowDouble)
+ return makeFixedHolder(Double.self, buffers: buffers, arrowType:
ArrowType.ArrowDouble, nullCount: nullCount)
default:
return .failure(.unknownType("Float precision \(floatType.precision)
currently not supported"))
}
}
private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date,
- buffers: [ArrowBuffer]
+ buffers: [ArrowBuffer],
+ nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
if dateType.unit == .day {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString),
buffers: buffers,
- nullCount: buffers[0].length,
stride: MemoryLayout<Date>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<Date>.stride)
return .success(ArrowArrayHolder(Date32Array(arrowData)))
}
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString),
buffers: buffers,
- nullCount: buffers[0].length, stride:
MemoryLayout<Date>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<Date>.stride)
return .success(ArrowArrayHolder(Date64Array(arrowData)))
} catch let error as ArrowError {
return .failure(error)
@@ -76,19 +80,20 @@ private func makeDateHolder(_ dateType:
org_apache_arrow_flatbuf_Date,
}
private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time,
- buffers: [ArrowBuffer]
+ buffers: [ArrowBuffer],
+ nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
if timeType.unit == .second || timeType.unit == .millisecond {
let arrowUnit: ArrowTime32Unit = timeType.unit == .second ?
.seconds : .milliseconds
let arrowData = try ArrowData(ArrowTypeTime32(arrowUnit), buffers:
buffers,
- nullCount: buffers[0].length,
stride: MemoryLayout<Time32>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<Time32>.stride)
return .success(ArrowArrayHolder(FixedArray<Time32>(arrowData)))
}
let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ?
.microseconds : .nanoseconds
let arrowData = try ArrowData(ArrowTypeTime64(arrowUnit), buffers:
buffers,
- nullCount: buffers[0].length, stride:
MemoryLayout<Time64>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<Time64>.stride)
return .success(ArrowArrayHolder(FixedArray<Time64>(arrowData)))
} catch let error as ArrowError {
return .failure(error)
@@ -97,10 +102,11 @@ private func makeTimeHolder(_ timeType:
org_apache_arrow_flatbuf_Time,
}
}
-private func makeBoolHolder(_ buffers: [ArrowBuffer]) ->
Result<ArrowArrayHolder, ArrowError> {
+private func makeBoolHolder(_ buffers: [ArrowBuffer],
+ nullCount: UInt) -> Result<ArrowArrayHolder,
ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers:
buffers,
- nullCount: buffers[0].length, stride:
MemoryLayout<UInt8>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<UInt8>.stride)
return .success(ArrowArrayHolder(BoolArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
@@ -111,11 +117,12 @@ private func makeBoolHolder(_ buffers: [ArrowBuffer]) ->
Result<ArrowArrayHolder
private func makeFixedHolder<T>(
_: T.Type, buffers: [ArrowBuffer],
- arrowType: ArrowType.Info
+ arrowType: ArrowType.Info,
+ nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers,
- nullCount: buffers[0].length, stride:
MemoryLayout<T>.stride)
+ nullCount: nullCount, stride:
MemoryLayout<T>.stride)
return .success(ArrowArrayHolder(FixedArray<T>(arrowData)))
} catch let error as ArrowError {
return .failure(error)
@@ -124,9 +131,10 @@ private func makeFixedHolder<T>(
}
}
-func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
+func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
function_body_length
_ field: org_apache_arrow_flatbuf_Field,
- buffers: [ArrowBuffer]
+ buffers: [ArrowBuffer],
+ nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
let type = field.typeType
switch type {
@@ -135,45 +143,53 @@ func makeArrayHolder( // swiftlint:disable:this
cyclomatic_complexity
let bitWidth = intType.bitWidth
if bitWidth == 8 {
if intType.isSigned {
- return makeFixedHolder(Int8.self, buffers: buffers, arrowType:
ArrowType.ArrowInt8)
+ return makeFixedHolder(Int8.self, buffers: buffers,
+ arrowType: ArrowType.ArrowInt8,
nullCount: nullCount)
} else {
- return makeFixedHolder(UInt8.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt8)
+ return makeFixedHolder(UInt8.self, buffers: buffers,
+ arrowType: ArrowType.ArrowUInt8,
nullCount: nullCount)
}
} else if bitWidth == 16 {
if intType.isSigned {
- return makeFixedHolder(Int16.self, buffers: buffers,
arrowType: ArrowType.ArrowInt16)
+ return makeFixedHolder(Int16.self, buffers: buffers,
+ arrowType: ArrowType.ArrowInt16,
nullCount: nullCount)
} else {
- return makeFixedHolder(UInt16.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt16)
+ return makeFixedHolder(UInt16.self, buffers: buffers,
+ arrowType: ArrowType.ArrowUInt16,
nullCount: nullCount)
}
} else if bitWidth == 32 {
if intType.isSigned {
- return makeFixedHolder(Int32.self, buffers: buffers,
arrowType: ArrowType.ArrowInt32)
+ return makeFixedHolder(Int32.self, buffers: buffers,
+ arrowType: ArrowType.ArrowInt32,
nullCount: nullCount)
} else {
- return makeFixedHolder(UInt32.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt32)
+ return makeFixedHolder(UInt32.self, buffers: buffers,
+ arrowType: ArrowType.ArrowUInt32,
nullCount: nullCount)
}
} else if bitWidth == 64 {
if intType.isSigned {
- return makeFixedHolder(Int64.self, buffers: buffers,
arrowType: ArrowType.ArrowInt64)
+ return makeFixedHolder(Int64.self, buffers: buffers,
+ arrowType: ArrowType.ArrowInt64,
nullCount: nullCount)
} else {
- return makeFixedHolder(UInt64.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt64)
+ return makeFixedHolder(UInt64.self, buffers: buffers,
+ arrowType: ArrowType.ArrowUInt64,
nullCount: nullCount)
}
}
return .failure(.unknownType("Int width \(bitWidth) currently not
supported"))
case .bool:
- return makeBoolHolder(buffers)
+ return makeBoolHolder(buffers, nullCount: nullCount)
case .floatingpoint:
let floatType = field.type(type:
org_apache_arrow_flatbuf_FloatingPoint.self)!
- return makeFloatHolder(floatType, buffers: buffers)
+ return makeFloatHolder(floatType, buffers: buffers, nullCount:
nullCount)
case .utf8:
- return makeStringHolder(buffers)
+ return makeStringHolder(buffers, nullCount: nullCount)
case .binary:
- return makeBinaryHolder(buffers)
+ return makeBinaryHolder(buffers, nullCount: nullCount)
case .date:
let dateType = field.type(type: org_apache_arrow_flatbuf_Date.self)!
- return makeDateHolder(dateType, buffers: buffers)
+ return makeDateHolder(dateType, buffers: buffers, nullCount: nullCount)
case .time:
let timeType = field.type(type: org_apache_arrow_flatbuf_Time.self)!
- return makeTimeHolder(timeType, buffers: buffers)
+ return makeTimeHolder(timeType, buffers: buffers, nullCount: nullCount)
default:
return .failure(.unknownType("Type \(type) currently not supported"))
}
diff --git a/swift/Arrow/Tests/ArrowTests/IPCTests.swift
b/swift/Arrow/Tests/ArrowTests/IPCTests.swift
index 59cad94ef4..103c3b24c7 100644
--- a/swift/Arrow/Tests/ArrowTests/IPCTests.swift
+++ b/swift/Arrow/Tests/ArrowTests/IPCTests.swift
@@ -64,14 +64,16 @@ func makeSchema() -> ArrowSchema {
return schemaBuilder.addField("col1", type:
ArrowType(ArrowType.ArrowUInt8), isNullable: true)
.addField("col2", type: ArrowType(ArrowType.ArrowString), isNullable:
false)
.addField("col3", type: ArrowType(ArrowType.ArrowDate32), isNullable:
false)
+ .addField("col4", type: ArrowType(ArrowType.ArrowInt32), isNullable:
false)
+ .addField("col5", type: ArrowType(ArrowType.ArrowFloat), isNullable:
false)
.finish()
}
func makeRecordBatch() throws -> RecordBatch {
let uint8Builder: NumberArrayBuilder<UInt8> = try
ArrowArrayBuilders.loadNumberArrayBuilder()
uint8Builder.append(10)
- uint8Builder.append(22)
- uint8Builder.append(33)
+ uint8Builder.append(nil)
+ uint8Builder.append(nil)
uint8Builder.append(44)
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
stringBuilder.append("test10")
@@ -85,13 +87,28 @@ func makeRecordBatch() throws -> RecordBatch {
date32Builder.append(date2)
date32Builder.append(date1)
date32Builder.append(date2)
- let intHolder = ArrowArrayHolder(try uint8Builder.finish())
+ let int32Builder: NumberArrayBuilder<Int32> = try
ArrowArrayBuilders.loadNumberArrayBuilder()
+ int32Builder.append(1)
+ int32Builder.append(2)
+ int32Builder.append(3)
+ int32Builder.append(4)
+ let floatBuilder: NumberArrayBuilder<Float> = try
ArrowArrayBuilders.loadNumberArrayBuilder()
+ floatBuilder.append(211.112)
+ floatBuilder.append(322.223)
+ floatBuilder.append(433.334)
+ floatBuilder.append(544.445)
+
+ let uint8Holder = ArrowArrayHolder(try uint8Builder.finish())
let stringHolder = ArrowArrayHolder(try stringBuilder.finish())
let date32Holder = ArrowArrayHolder(try date32Builder.finish())
+ let int32Holder = ArrowArrayHolder(try int32Builder.finish())
+ let floatHolder = ArrowArrayHolder(try floatBuilder.finish())
let result = RecordBatch.Builder()
- .addColumn("col1", arrowArray: intHolder)
+ .addColumn("col1", arrowArray: uint8Holder)
.addColumn("col2", arrowArray: stringHolder)
.addColumn("col3", arrowArray: date32Holder)
+ .addColumn("col4", arrowArray: int32Holder)
+ .addColumn("col5", arrowArray: floatHolder)
.finish()
switch result {
case .success(let recordBatch):
@@ -182,15 +199,20 @@ final class IPCFileReaderTests: XCTestCase {
XCTAssertEqual(recordBatches.count, 1)
for recordBatch in recordBatches {
XCTAssertEqual(recordBatch.length, 4)
- XCTAssertEqual(recordBatch.columns.count, 3)
- XCTAssertEqual(recordBatch.schema.fields.count, 3)
+ XCTAssertEqual(recordBatch.columns.count, 5)
+ XCTAssertEqual(recordBatch.schema.fields.count, 5)
XCTAssertEqual(recordBatch.schema.fields[0].name, "col1")
XCTAssertEqual(recordBatch.schema.fields[0].type.info,
ArrowType.ArrowUInt8)
XCTAssertEqual(recordBatch.schema.fields[1].name, "col2")
XCTAssertEqual(recordBatch.schema.fields[1].type.info,
ArrowType.ArrowString)
XCTAssertEqual(recordBatch.schema.fields[2].name, "col3")
XCTAssertEqual(recordBatch.schema.fields[2].type.info,
ArrowType.ArrowDate32)
+ XCTAssertEqual(recordBatch.schema.fields[3].name, "col4")
+ XCTAssertEqual(recordBatch.schema.fields[3].type.info,
ArrowType.ArrowInt32)
+ XCTAssertEqual(recordBatch.schema.fields[4].name, "col5")
+ XCTAssertEqual(recordBatch.schema.fields[4].type.info,
ArrowType.ArrowFloat)
let columns = recordBatch.columns
+ XCTAssertEqual(columns[0].nullCount, 2)
let dateVal =
"\((columns[2].array as! AsString).asString(0))" //
swiftlint:disable:this force_cast
XCTAssertEqual(dateVal, "2014-09-10 00:00:00 +0000")
@@ -227,13 +249,17 @@ final class IPCFileReaderTests: XCTestCase {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
- XCTAssertEqual(schema.fields.count, 3)
+ XCTAssertEqual(schema.fields.count, 5)
XCTAssertEqual(schema.fields[0].name, "col1")
XCTAssertEqual(schema.fields[0].type.info,
ArrowType.ArrowUInt8)
XCTAssertEqual(schema.fields[1].name, "col2")
XCTAssertEqual(schema.fields[1].type.info,
ArrowType.ArrowString)
XCTAssertEqual(schema.fields[2].name, "col3")
XCTAssertEqual(schema.fields[2].type.info,
ArrowType.ArrowDate32)
+ XCTAssertEqual(schema.fields[3].name, "col4")
+ XCTAssertEqual(schema.fields[3].type.info,
ArrowType.ArrowInt32)
+ XCTAssertEqual(schema.fields[4].name, "col5")
+ XCTAssertEqual(schema.fields[4].type.info,
ArrowType.ArrowFloat)
case.failure(let error):
throw error
}
diff --git a/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift
b/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift
index ab6cad1b5e..8820f1cdb1 100644
--- a/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift
+++ b/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift
@@ -23,9 +23,11 @@ final class RecordBatchTests: XCTestCase {
let uint8Builder: NumberArrayBuilder<UInt8> = try
ArrowArrayBuilders.loadNumberArrayBuilder()
uint8Builder.append(10)
uint8Builder.append(22)
+ uint8Builder.append(nil)
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
stringBuilder.append("test10")
stringBuilder.append("test22")
+ stringBuilder.append("test33")
let intHolder = ArrowArrayHolder(try uint8Builder.finish())
let stringHolder = ArrowArrayHolder(try stringBuilder.finish())
@@ -39,15 +41,16 @@ final class RecordBatchTests: XCTestCase {
XCTAssertEqual(schema.fields.count, 2)
XCTAssertEqual(schema.fields[0].name, "col1")
XCTAssertEqual(schema.fields[0].type.info, ArrowType.ArrowUInt8)
- XCTAssertEqual(schema.fields[0].isNullable, false)
+ XCTAssertEqual(schema.fields[0].isNullable, true)
XCTAssertEqual(schema.fields[1].name, "col2")
XCTAssertEqual(schema.fields[1].type.info, ArrowType.ArrowString)
XCTAssertEqual(schema.fields[1].isNullable, false)
XCTAssertEqual(recordBatch.columns.count, 2)
let col1: ArrowArray<UInt8> = recordBatch.data(for: 0)
let col2: ArrowArray<String> = recordBatch.data(for: 1)
- XCTAssertEqual(col1.length, 2)
- XCTAssertEqual(col2.length, 2)
+ XCTAssertEqual(col1.length, 3)
+ XCTAssertEqual(col2.length, 3)
+ XCTAssertEqual(col1.nullCount, 1)
case .failure(let error):
throw error
}