Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r194965715
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala ---
@@ -1318,18 +1318,52 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
}
}
- test("roundtrip payloads") {
+ test("roundtrip arrow batches") {
val inputRows = (0 until 9).map { i =>
InternalRow(i)
} :+ InternalRow(null)
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
- val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx)
- val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
+ val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx)
- assert(schema == outputRowIter.schema)
+ var count = 0
+ outputRowIter.zipWithIndex.foreach { case (row, i) =>
+ if (i != 9) {
+ assert(row.getInt(0) == i)
+ } else {
+ assert(row.isNullAt(0))
+ }
+ count += 1
+ }
+
+ assert(count == inputRows.length)
+ }
+
+ test("ArrowBatchStreamWriter roundtrip") {
+ val inputRows = (0 until 9).map { i =>
+ InternalRow(i)
+ } :+ InternalRow(null)
+
+ val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
+
+ val ctx = TaskContext.empty()
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
+
+ // Write batches to Arrow stream format as a byte array
+ val out = new ByteArrayOutputStream()
--- End diff --
Can we use `Utils.tryWithResource { ... }`?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]