This is an automated email from the ASF dual-hosted git repository.
alexey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git
The following commit(s) were added to refs/heads/master by this push:
new 1d0cb013a KUDU-1261 [Java] Add spark bindings for Array columns
1d0cb013a is described below
commit 1d0cb013a5aad754f2ed06c25f5ec3d8a1873622
Author: Abhishek Chennaka <[email protected]>
AuthorDate: Mon Oct 20 15:12:10 2025 -0700
KUDU-1261 [Java] Add spark bindings for Array columns
Introduces end-to-end support for Array types in the Kudu Spark bindings.
Extends RowConverter, SparkUtil, and KuduRDD to serialize and deserialize
array columns between Kudu and Spark SQL. Adds integration tests covering
read/write and round-trip conversion for primitive and string array types.
Updates the kudu-backup metadata structure to accommodate array-typed
column schemas.
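As an illustrative sketch (not part of this patch), an array column can be
read and written through the Spark bindings roughly as follows; the master
address, table name, and the ARRAY<STRING> column 'tags' are hypothetical
placeholders:

    import org.apache.spark.sql.SparkSession
    import org.apache.kudu.spark.kudu.KuduContext

    // Sketch only: master address, table and column names are placeholders.
    val spark = SparkSession.builder().appName("kudu-array-example").getOrCreate()

    // Array columns surface as Spark ArrayType fields when reading.
    val df = spark.read
      .options(Map("kudu.master" -> "master-host:7051", "kudu.table" -> "my_table"))
      .format("kudu")
      .load()
    df.select("key", "tags").show()

    // A DataFrame column of ArrayType maps back to a Kudu array column on write.
    val kuduContext = new KuduContext("master-host:7051", spark.sparkContext)
    kuduContext.insertRows(df, "my_table")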
Change-Id: I786fdd1cbfbb67b4895b2e95b89addbc04341746
Reviewed-on: http://gerrit.cloudera.org:8080/23565
Tested-by: Alexey Serbin <[email protected]>
Reviewed-by: Alexey Serbin <[email protected]>
---
.../src/main/protobuf/backup.proto | 12 ++
.../org/apache/kudu/backup/TableMetadata.scala | 40 +++-
.../org/apache/kudu/backup/TestKuduBackup.scala | 5 +-
.../org/apache/kudu/spark/kudu/KuduContext.scala | 9 +-
.../scala/org/apache/kudu/spark/kudu/KuduRDD.scala | 12 +-
.../org/apache/kudu/spark/kudu/RowConverter.scala | 211 ++++++++++++++++++++-
.../org/apache/kudu/spark/kudu/SparkUtil.scala | 124 ++++++++----
.../apache/kudu/spark/kudu/DefaultSourceTest.scala | 2 +-
.../apache/kudu/spark/kudu/KuduContextTest.scala | 155 +++++++++++++++
.../org/apache/kudu/spark/kudu/KuduTestSuite.scala | 153 ++++++++++++++-
.../org/apache/kudu/spark/kudu/SparkSQLTest.scala | 117 ++++++++++++
11 files changed, 787 insertions(+), 53 deletions(-)
diff --git a/java/kudu-backup-common/src/main/protobuf/backup.proto b/java/kudu-backup-common/src/main/protobuf/backup.proto
index 77a13b3e3..7ddc28dc7 100644
--- a/java/kudu-backup-common/src/main/protobuf/backup.proto
+++ b/java/kudu-backup-common/src/main/protobuf/backup.proto
@@ -50,6 +50,18 @@ message ColumnMetadataPB {
int32 block_size = 9;
string comment = 10;
bool is_auto_incrementing = 11;
+ // Descriptor present only when type == "NESTED".
+ NestedTypeDescriptorMetadataPB nested_type_descriptor = 12;
+}
+
+// Stores nested column type information
+message NestedTypeDescriptorMetadataPB {
+ ArrayTypeDescriptorMetadataPB array_descriptor = 1;
+}
+
+// Stores array column (subtype of nested) type information
+message ArrayTypeDescriptorMetadataPB {
+ string elem_type = 1;
}
// A human readable string representation of a column value for use
diff --git a/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala b/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala
index d532e744e..a5278153e 100644
--- a/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala
+++ b/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala
@@ -18,7 +18,6 @@
package org.apache.kudu.backup
import java.math.BigDecimal
-import java.nio.ByteBuffer
import java.sql.Date
import java.util
@@ -78,6 +77,21 @@ object TableMetadata {
if (col.getDefaultValue != null) {
builder.setDefaultValue(StringValue.of(valueToString(col.getDefaultValue,
col.getType)))
}
+ if (col.getType == Type.NESTED && col.getNestedTypeDescriptor != null) {
+ val nestedDesc = col.getNestedTypeDescriptor
+ if (nestedDesc.isArray) {
+ val arrDesc = nestedDesc.getArrayDescriptor
+ val arrayBuilder = ArrayTypeDescriptorMetadataPB
+ .newBuilder()
+ .setElemType(arrDesc.getElemType.name())
+ val nestedBuilder = NestedTypeDescriptorMetadataPB
+ .newBuilder()
+ .setArrayDescriptor(arrayBuilder.build())
+ builder
+ .setType(Type.NESTED.name())
+ .setNestedTypeDescriptor(nestedBuilder.build())
+ }
+ }
builder.build()
}
@@ -249,13 +263,33 @@ object TableMetadata {
val toId = metadata.getColumnIdsMap.asScala
metadata.getColumnsList.asScala.foreach { col =>
if (!col.getIsAutoIncrementing) {
- val colType = Type.getTypeForName(col.getType)
+ var colType = Type.getTypeForName(col.getType)
+ var isArray = false
+
+ // Detect and reconstruct array columns
+ if (colType == Type.NESTED) {
+ if (!col.hasNestedTypeDescriptor) {
+ throw new IllegalArgumentException(
+ s"Column ${col.getName} is marked NESTED but missing NestedTypeDescriptor")
+ }
+ val nestedPB = col.getNestedTypeDescriptor
+ if (!nestedPB.hasArrayDescriptor) {
+ throw new IllegalArgumentException(
+ s"Column ${col.getName} is NESTED but not an ARRAY (unsupported nested subtype)")
+ }
+ val arrayPB = nestedPB.getArrayDescriptor
+ colType = Type.getTypeForName(arrayPB.getElemType)
+ isArray = true
+ }
+
val builder = new ColumnSchemaBuilder(col.getName, colType)
.nullable(col.getIsNullable)
.encoding(Encoding.valueOf(col.getEncoding))
.compressionAlgorithm(CompressionAlgorithm.valueOf(col.getCompression))
.desiredBlockSize(col.getBlockSize)
.comment(col.getComment)
+ .array(isArray)
+
if (IsAutoIncrementingPresent) {
builder.nonUniqueKey(col.getIsKey)
} else {
@@ -277,10 +311,12 @@ object TableMetadata {
.build()
)
}
+
colIds.add(toId(col.getName))
columns.add(builder.build())
}
}
+
new Schema(columns, colIds)
}
diff --git a/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala b/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala
index e2f2c481a..54653418a 100644
--- a/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala
+++ b/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala
@@ -1188,8 +1188,9 @@ class TestKuduBackup extends KuduTestSuite {
Objects.equal(before.getEncoding, after.getEncoding) &&
Objects
.equal(before.getCompressionAlgorithm, after.getCompressionAlgorithm) &&
- Objects.equal(before.getTypeAttributes, after.getTypeAttributes)
- Objects.equal(before.getComment, after.getComment)
+ Objects.equal(before.getTypeAttributes, after.getTypeAttributes) &&
+ Objects.equal(before.getComment, after.getComment) &&
+ Objects.equal(before.isArray, after.isArray)
}
// Special handling because default values can be a byte array which is not
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
index 4bc00d172..f7c6d98e6 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
@@ -381,7 +381,14 @@ class KuduContext(
s"adding ${newColumns.length} columns to table '$tableName' to handle schema drift")
val alter = new AlterTableOptions()
newColumns.foreach { col =>
- alter.addNullableColumn(col.name, sparkTypeToKuduType(col.dataType))
+ col.dataType match {
+ case at: org.apache.spark.sql.types.ArrayType =>
+ val elemType = sparkTypeToKuduType(at.elementType)
+ alter.addNullableArrayColumn(col.name, elemType)
+
+ case _ =>
+ alter.addNullableColumn(col.name, sparkTypeToKuduType(col.dataType))
+ }
}
try {
syncClient.alterTable(tableName, alter)
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
index 84b004f22..8bf13a49b 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
@@ -165,7 +165,17 @@ private class RowIterator(
val columnCount = rowResult.getColumnProjection.getColumnCount
val columns = Array.ofDim[Any](columnCount)
for (i <- 0 until columnCount) {
- columns(i) = rowResult.getObject(i)
+ val col = rowResult.getColumnProjection.getColumnByIndex(i)
+ if (col.isArray) {
+ val arrObj = rowResult.getArrayData(i)
+ columns(i) = if (arrObj == null) {
+ null
+ } else {
+ arrObj.asInstanceOf[Array[_]].toIndexedSeq
+ }
+ } else {
+ columns(i) = rowResult.getObject(i)
+ }
}
Row.fromSeq(columns)
}
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala
index f4c52c82b..3797d061a 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/RowConverter.scala
@@ -25,11 +25,13 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.StructType
import org.apache.yetus.audience.InterfaceAudience
import org.apache.yetus.audience.InterfaceStability
+import scala.collection.JavaConverters._
@InterfaceAudience.Private
@InterfaceStability.Unstable
@@ -58,16 +60,28 @@ class RowConverter(kuduSchema: Schema, schema: StructType, ignoreNull: Boolean)
def toPartialRow(row: Row): PartialRow = {
val partialRow = kuduSchema.newPartialRow()
for ((sparkIdx, kuduIdx) <- indices) {
+ val col = kuduSchema.getColumnByIndex(kuduIdx)
if (row.isNullAt(sparkIdx)) {
- if (kuduSchema.getColumnByIndex(kuduIdx).isKey) {
- val key_name = kuduSchema.getColumnByIndex(kuduIdx).getName
- throw new IllegalArgumentException(s"Can't set primary key column '$key_name' to null")
+ if (col.isKey) {
+ throw new IllegalArgumentException(
+ s"Can't set primary key column '${col.getName}' to null")
}
if (!ignoreNull) partialRow.setNull(kuduIdx)
} else {
schema.fields(sparkIdx).dataType match {
+
+ // ========== ARRAY WRITE ==========
+ case ArrayType(elemType, containsNull) =>
+ val seq = row.getList[Any](sparkIdx).asScala
+ writeArray(
+ partialRow,
+ kuduIdx,
+ col.getNestedTypeDescriptor.getArrayDescriptor.getElemType,
+ seq)
+
+ // ========== SCALAR TYPES ==========
case DataTypes.StringType =>
- kuduSchema.getColumnByIndex(kuduIdx).getType match {
+ col.getType match {
case Type.STRING =>
partialRow.addString(kuduIdx, row.getString(sparkIdx))
case Type.VARCHAR =>
@@ -112,7 +126,17 @@ class RowConverter(kuduSchema: Schema, schema: StructType, ignoreNull: Boolean)
val columnCount = rowResult.getColumnProjection.getColumnCount
val columns = Array.ofDim[Any](columnCount)
for (i <- 0 until columnCount) {
- columns(i) = rowResult.getObject(i)
+ val col = rowResult.getColumnProjection.getColumnByIndex(i)
+ if (rowResult.isNull(i)) {
+ columns(i) = null
+ } else if (col.isArray) {
+ val arrObj = rowResult.getArrayData(i)
+ columns(i) =
+ if (arrObj == null) null
+ else arrObj.asInstanceOf[Array[_]].toIndexedSeq
+ } else {
+ columns(i) = rowResult.getObject(i)
+ }
}
new GenericRowWithSchema(columns, schema)
}
@@ -124,8 +148,183 @@ class RowConverter(kuduSchema: Schema, schema: StructType, ignoreNull: Boolean)
val columnCount = partialRow.getSchema.getColumnCount
val columns = Array.ofDim[Any](columnCount)
for (i <- 0 until columnCount) {
- columns(i) = partialRow.getObject(i)
+ val col = partialRow.getSchema.getColumnByIndex(i)
+ if (partialRow.isSet(i)) {
+ if (col.isArray) {
+ val arrObj = partialRow.getArrayData(i)
+ columns(i) =
+ if (arrObj == null) null
+ else arrObj.asInstanceOf[Array[_]].toIndexedSeq
+ } else {
+ columns(i) = partialRow.getObject(i)
+ }
+ } else {
+ columns(i) = null
+ }
}
new GenericRowWithSchema(columns, schema)
}
+
+ // ---------------------------------------------------------------------
+ // Array write helper
+ // ---------------------------------------------------------------------
+ //
+ // Converts a Scala Seq[Any] into Kudu's array cell format for the given
+ // element type. Builds parallel 'data' and 'validity' arrays, then calls
+ // the corresponding PartialRow.addArray*() method (e.g. addArrayInt32,
+ // addArrayString, etc.). Null elements are recorded as invalid in the
+ // validity mask.
+ private def writeArray(pr: PartialRow, idx: Int, elemKuduType: Type, seq: Seq[Any]): Unit = {
+
+ val n = seq.length
+ val validity = new Array[Boolean](n)
+
+ elemKuduType match {
+ case Type.BOOL =>
+ val data = new Array[Boolean](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.asInstanceOf[Boolean] }
+ i += 1
+ }
+ pr.addArrayBool(idx, data, validity)
+
+ case Type.INT8 =>
+ val data = new Array[Byte](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.asInstanceOf[Byte] }
+ i += 1
+ }
+ pr.addArrayInt8(idx, data, validity)
+
+ case Type.INT16 =>
+ val data = new Array[Short](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.asInstanceOf[Short] }
+ i += 1
+ }
+ pr.addArrayInt16(idx, data, validity)
+
+ case Type.INT32 =>
+ val data = new Array[Int](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.asInstanceOf[Int] }
+ i += 1
+ }
+ pr.addArrayInt32(idx, data, validity)
+
+ case Type.INT64 =>
+ val data = new Array[Long](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.asInstanceOf[Long] }
+ i += 1
+ }
+ pr.addArrayInt64(idx, data, validity)
+
+ case Type.DOUBLE =>
+ val data = new Array[Double](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.asInstanceOf[Double] }
+ i += 1
+ }
+ pr.addArrayDouble(idx, data, validity)
+
+ case Type.STRING | Type.VARCHAR =>
+ val data = new Array[String](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.toString }
+ i += 1
+ }
+ pr.addArrayString(idx, data, validity)
+
+ case Type.BINARY =>
+ val data = new Array[Array[Byte]](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else { validity(i) = true; data(i) = v.asInstanceOf[Array[Byte]] }
+ i += 1
+ }
+ pr.addArrayBinary(idx, data, validity)
+
+ case Type.FLOAT =>
+ val data = new Array[Float](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else {
+ validity(i) = true; data(i) = v.asInstanceOf[Float]
+ }
+ i += 1
+ }
+ pr.addArrayFloat(idx, data, validity)
+
+ case Type.DATE =>
+ val data = new Array[java.sql.Date](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else {
+ validity(i) = true
+ data(i) = v.asInstanceOf[java.sql.Date]
+ }
+ i += 1
+ }
+ pr.addArrayDate(idx, data, validity)
+
+ case Type.UNIXTIME_MICROS =>
+ val data = new Array[java.sql.Timestamp](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else {
+ validity(i) = true
+ data(i) = v.asInstanceOf[java.sql.Timestamp]
+ }
+ i += 1
+ }
+ pr.addArrayTimestamp(idx, data, validity)
+
+ case Type.DECIMAL =>
+ val data = new Array[java.math.BigDecimal](n)
+ var i = 0
+ while (i < n) {
+ val v = seq(i)
+ if (v == null) validity(i) = false
+ else {
+ validity(i) = true;
+ data(i) = v.asInstanceOf[java.math.BigDecimal]
+ }
+ i += 1
+ }
+ pr.addArrayDecimal(idx, data, validity)
+
+ case t =>
+ throw new IllegalArgumentException(s"Unsupported Kudu array element type $t")
+ }
+ }
}
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/SparkUtil.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/SparkUtil.scala
index bf34c9065..8b115c41e 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/SparkUtil.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/SparkUtil.scala
@@ -25,7 +25,6 @@ import org.apache.yetus.audience.InterfaceStability
import org.apache.kudu.ColumnTypeAttributes.ColumnTypeAttributesBuilder
import org.apache.kudu.ColumnSchema
-import org.apache.kudu.ColumnTypeAttributes
import org.apache.kudu.Schema
import org.apache.kudu.Type
@@ -36,30 +35,49 @@ import scala.collection.JavaConverters._
object SparkUtil {
/**
- * Converts a Kudu [[Type]] to a Spark SQL [[DataType]].
+ * Converts a Kudu [[ColumnSchema]] to a Spark SQL [[DataType]].
*
- * @param t the Kudu type
- * @param a the Kudu type attributes
- * @return the corresponding Spark SQL type
+ * Handles both scalar and array (nested) columns.
+ *
+ * @param col the Kudu column schema
+ * @return the corresponding Spark SQL data type
*/
- def kuduTypeToSparkType(t: Type, a: ColumnTypeAttributes): DataType =
- t match {
- case Type.BOOL => BooleanType
- case Type.INT8 => ByteType
- case Type.INT16 => ShortType
- case Type.INT32 => IntegerType
- case Type.INT64 => LongType
- case Type.UNIXTIME_MICROS => TimestampType
- case Type.DATE => DateType
- case Type.FLOAT => FloatType
- case Type.DOUBLE => DoubleType
- case Type.VARCHAR => StringType
- case Type.STRING => StringType
- case Type.BINARY => BinaryType
- case Type.DECIMAL => DecimalType(a.getPrecision, a.getScale)
- case _ =>
- throw new IllegalArgumentException(s"No support for Kudu type $t")
+ def kuduTypeToSparkType(col: ColumnSchema): DataType = {
+ val t = col.getType
+ val a = col.getTypeAttributes
+
+ if (col.isArray) {
+ val elemType = col.getNestedTypeDescriptor.getArrayDescriptor.getElemType
+ val elemSparkType = kuduTypeToSparkType(
+ new ColumnSchema.ColumnSchemaBuilder("elem", elemType)
+ .typeAttributes(a)
+ .build()
+ )
+ ArrayType(elemSparkType, containsNull = true)
+ } else {
+ t match {
+ case Type.BOOL => BooleanType
+ case Type.INT8 => ByteType
+ case Type.INT16 => ShortType
+ case Type.INT32 => IntegerType
+ case Type.INT64 => LongType
+ case Type.UNIXTIME_MICROS => TimestampType
+ case Type.DATE => DateType
+ case Type.FLOAT => FloatType
+ case Type.DOUBLE => DoubleType
+ case Type.VARCHAR => StringType
+ case Type.STRING => StringType
+ case Type.BINARY => BinaryType
+ case Type.DECIMAL => DecimalType(a.getPrecision, a.getScale)
+ case Type.NESTED =>
+ throw new IllegalArgumentException(
+ "Type.NESTED should not be converted directly;" +
+ " handle via col.isArray/col.getElementType.")
+ case other =>
+ throw new IllegalArgumentException(s"No support for Spark SQL type $other")
+ }
}
+ }
/**
* Converts a Spark SQL [[DataType]] to a Kudu [[Type]].
@@ -80,6 +98,7 @@ object SparkUtil {
case DataTypes.FloatType => Type.FLOAT
case DataTypes.DoubleType => Type.DOUBLE
case DecimalType() => Type.DECIMAL
+ case at: ArrayType => Type.NESTED
case _ =>
throw new IllegalArgumentException(s"No support for Spark SQL type $dt")
}
@@ -92,13 +111,12 @@ object SparkUtil {
* @return the SparkSQL schema
*/
def sparkSchema(kuduSchema: Schema, fields: Option[Seq[String]] = None): StructType = {
- val kuduColumns = fields match {
+ val kuduColumns: Seq[ColumnSchema] = fields match {
case Some(fieldNames) => fieldNames.map(kuduSchema.getColumn)
case None => kuduSchema.getColumns.asScala
}
val sparkColumns = kuduColumns.map { col =>
- val sparkType = kuduTypeToSparkType(col.getType, col.getTypeAttributes)
- StructField(col.getName, sparkType, col.isNullable)
+ StructField(col.getName, kuduTypeToSparkType(col), col.isNullable)
}
StructType(sparkColumns)
}
@@ -115,6 +133,10 @@ object SparkUtil {
// add the key columns first, in the order specified
for (key <- keys) {
val field = sparkSchema.fields(sparkSchema.fieldIndex(key))
+ if (field.dataType.isInstanceOf[ArrayType]) {
+ throw new IllegalArgumentException(
+ s"Array-typed column '${field.name}' cannot be a Kudu key")
+ }
val col = createColumnSchema(field, isKey = true)
kuduCols.add(col)
}
@@ -134,21 +156,45 @@ object SparkUtil {
* @return the Kudu column schema
*/
private def createColumnSchema(field: StructField, isKey: Boolean): ColumnSchema = {
- val kt = sparkTypeToKuduType(field.dataType)
- val col = new ColumnSchema.ColumnSchemaBuilder(field.name, kt)
- .key(isKey)
- .nullable(field.nullable)
- // Add ColumnTypeAttributesBuilder to DECIMAL columns
- if (kt == Type.DECIMAL) {
- val dt = field.dataType.asInstanceOf[DecimalType]
- col.typeAttributes(
- new ColumnTypeAttributesBuilder()
- .precision(dt.precision)
- .scale(dt.scale)
+ field.dataType match {
+ case at @ ArrayType(elemDt, containsNull) =>
+ val elemKt = sparkTypeToKuduType(elemDt)
+ val b = new ColumnSchema.ColumnSchemaBuilder(field.name, elemKt)
+ .key(isKey)
+ .nullable(field.nullable)
+ .array(true)
+
+ // DECIMAL element attributes, if needed
+ elemDt match {
+ case d: DecimalType =>
+ b.typeAttributes(
+ new ColumnTypeAttributesBuilder()
+ .precision(d.precision)
+ .scale(d.scale)
+ .build()
+ )
+ case _ => // noop
+ }
+ b.build()
+
+ case d: DecimalType =>
+ val b = new ColumnSchema.ColumnSchemaBuilder(field.name, Type.DECIMAL)
+ .key(isKey)
+ .nullable(field.nullable)
+ .typeAttributes(
+ new ColumnTypeAttributesBuilder()
+ .precision(d.precision)
+ .scale(d.scale)
+ .build()
+ )
+ b.build()
+
+ case other =>
+ val kt = sparkTypeToKuduType(other)
+ new ColumnSchema.ColumnSchemaBuilder(field.name, kt)
+ .key(isKey)
+ .nullable(field.nullable)
.build()
- )
}
- col.build()
}
-
}
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index a90565c2b..4dea01444 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -781,7 +781,7 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
))
val dfDefaultSchema =
sqlContext.read.options(kuduOptions).format("kudu").load
- assertEquals(16, dfDefaultSchema.schema.fields.length)
+ assertEquals(30, dfDefaultSchema.schema.fields.length)
val dfWithUserSchema =
sqlContext.read.options(kuduOptions).schema(userSchema).format("kudu").load
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
index d70101632..715f51845 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
@@ -26,6 +26,11 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.sql.Date
import java.sql.Timestamp
+import org.apache.spark.sql.types._
+import org.apache.kudu.client._
+import org.apache.kudu.Type
+import scala.collection.JavaConverters._
+
import org.apache.kudu.util.DateUtil
import org.apache.kudu.util.TimestampUtil
import org.apache.spark.sql.functions.decode
@@ -141,4 +146,154 @@ class KuduContextTest extends KuduTestSuite with Matchers {
.get(0)
.shouldBe("bytes 0")
}
+
+ @Test
+ def testArrayColumnReadWrite(): Unit = {
+ insertRows(table, rowCount)
+
+ val dataDF = ss.read
+ .options(Map("kudu.master" -> harness.getMasterAddressesAsString, "kudu.table" -> "test"))
+ .format("kudu")
+ .load
+
+ val sample = dataDF
+ .select(
+ "c16_bool_array",
+ "c17_int8_array",
+ "c18_int16_array",
+ "c19_int32_array",
+ "c20_int64_array",
+ "c21_float_array",
+ "c22_double_array",
+ "c23_date_array",
+ "c24_unixtime_array",
+ "c25_string_array",
+ "c26_varchar_array",
+ "c27_binary_array",
+ "c28_decimal_array"
+ )
+ .limit(1)
+ .collect()(0)
+
+ // Note: getAs[] surfaces null array elements as nulls inside the returned Seq.
+
+ val bools = sample.getAs[Seq[Boolean]]("c16_bool_array")
+ val bytes = sample.getAs[Seq[Byte]]("c17_int8_array")
+ val shorts = sample.getAs[Seq[Short]]("c18_int16_array")
+ val ints = sample.getAs[Seq[Int]]("c19_int32_array")
+ val longs = sample.getAs[Seq[Long]]("c20_int64_array")
+ val floats = sample.getAs[Seq[Float]]("c21_float_array")
+ val doubles = sample.getAs[Seq[Double]]("c22_double_array")
+ val dates = sample.getAs[Seq[java.sql.Date]]("c23_date_array")
+ val timestamps = sample.getAs[Seq[java.sql.Timestamp]]("c24_unixtime_array")
+ val strings = sample.getAs[Seq[String]]("c25_string_array")
+ val varchars = sample.getAs[Seq[String]]("c26_varchar_array")
+ val binaries = sample.getAs[Seq[Array[Byte]]]("c27_binary_array")
+ val decimals = sample.getAs[Seq[java.math.BigDecimal]]("c28_decimal_array")
+
+ // Validate structure (size, presence)
+ val allArrays = Seq(
+ "bools" -> bools,
+ "bytes" -> bytes,
+ "shorts" -> shorts,
+ "ints" -> ints,
+ "longs" -> longs,
+ "floats" -> floats,
+ "doubles" -> doubles,
+ "dates" -> dates,
+ "timestamps" -> timestamps,
+ "strings" -> strings,
+ "varchars" -> varchars,
+ "binaries" -> binaries,
+ "decimals" -> decimals
+ )
+
+ allArrays.foreach {
+ case (name, arr) =>
+ assert(arr != null, s"$name array should not be null")
+ assert(arr.nonEmpty, s"$name array should not be empty")
+ assert(arr.size == 3, s"$name array should contain exactly three elements")
+ }
+
+ // Verify expected head values or pattern
+ assert(bools.head)
+ assert(bytes.head == 0)
+ assert(shorts.head == 0)
+ assert(ints.head == 0)
+ assert(longs.head == 0)
+ assert(math.abs(floats.head - 0.0f) < 1e-6)
+ assert(math.abs(doubles.head - 0.0) < 1e-12)
+ assert(strings.head.startsWith("val-"))
+ assert(varchars.head.startsWith("vchar-"))
+ assert(new String(binaries.head, UTF_8).startsWith("bin-"))
+ assert(decimals.head.compareTo(new java.math.BigDecimal("0.00")) == 0)
+
+ // Check null propagation: middle element was intentionally null (validity false)
+ allArrays.foreach {
+ case (name, arr) =>
+ if (arr.size > 1)
+ assert(arr(1) == null, s"Middle element of $name should be null (validity=false)")
+ }
+
+ // Write back one row to ensure insert works
+ val newDF = dataDF.limit(1).withColumn("key", dataDF("key") + 100)
+ kuduContext.insertRows(newDF, "test")
+
+ val checkDF = ss.read
+ .options(Map("kudu.master" -> harness.getMasterAddressesAsString, "kudu.table" -> "test"))
+ .format("kudu")
+ .load
+
+ assert(checkDF.filter("key = 100").count() == 1)
+ }
+
+ @Test
+ def testSchemaDriftWithNewArrayColumn(): Unit = {
+ val spark = ss
+ import spark.implicits._
+
+ val driftTable = "test_schema_drift_array"
+
+ // Create table using Spark StructType
+ val baseStruct = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType, nullable = true)
+
+ val createOpts = new CreateTableOptions()
+ .setRangePartitionColumns(List("id").asJava)
+ .setNumReplicas(1)
+
+ // use the KuduContext overload that accepts a Spark StructType
+ kuduContext.createTable(driftTable, baseStruct, keys = Seq("id"), createOpts)
+
+ // Seed some rows (no array column yet)
+ val df1 = Seq((1, "foo"), (2, "bar")).toDF("id", "name")
+ kuduContext.insertRows(df1, driftTable)
+
+ // Evolve the DF schema by adding a new ARRAY<STRING> column
+ val df2 = Seq(
+ (3, "baz", Seq("alpha", "beta")),
+ (4, "qux", Seq("gamma", "delta"))
+ ).toDF("id", "name", "tags")
+
+ // Write the rows with SchemaDrift enabled
+ val writeOpts = KuduWriteOptions(handleSchemaDrift = true)
+ kuduContext.insertRows(df2, driftTable, writeOpts)
+
+ // Verify schema drift materialized the array column
+ val kuduSchema = kuduContext.syncClient.openTable(driftTable).getSchema
+ val tagsCol = kuduSchema.getColumn("tags")
+ assert(tagsCol != null, "Schema drift should have added 'tags'")
+ assert(tagsCol.isArray, "tags should be an array column")
+ assert(tagsCol.getType == Type.NESTED, "array columns are stored as Type.NESTED")
+
+ // Verify data round-trip
+ val readDf = spark.read
+ .options(Map("kudu.master" -> harness.getMasterAddressesAsString, "kudu.table" -> driftTable))
+ .format("kudu")
+ .load
+
+ val rows = readDf.filter("id >= 3").orderBy("id").collect()
+ assert(rows.head.getAs[Seq[String]]("tags") == Seq("alpha", "beta"))
+ }
}
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
index 6f5c4fe43..d2c2306f7 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
@@ -96,7 +96,70 @@ trait KuduTestSuite {
.typeAttributes(CharUtil.typeAttributes(CharUtil.MAX_VARCHAR_LENGTH))
.nullable(true)
.build(),
- new ColumnSchemaBuilder("c15_date", Type.DATE).build()
+ new ColumnSchemaBuilder("c15_date", Type.DATE).build(),
+ // ===== ARRAY TYPES =====
+ new ColumnSchemaBuilder("c16_bool_array", Type.BOOL)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c17_int8_array", Type.INT8)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c18_int16_array", Type.INT16)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c19_int32_array", Type.INT32)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c20_int64_array", Type.INT64)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c21_float_array", Type.FLOAT)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c22_double_array", Type.DOUBLE)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c23_date_array", Type.DATE)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c24_unixtime_array", Type.UNIXTIME_MICROS)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c25_string_array", Type.STRING)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c26_varchar_array", Type.VARCHAR)
+ .typeAttributes(CharUtil.typeAttributes(CharUtil.MAX_VARCHAR_LENGTH))
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c27_binary_array", Type.BINARY)
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c28_decimal_array", Type.DECIMAL)
+ .typeAttributes(
+ new ColumnTypeAttributesBuilder()
+ .precision(10)
+ .scale(2)
+ .build()
+ )
+ .nullable(true)
+ .array(true)
+ .build(),
+ new ColumnSchemaBuilder("c29_non_nullable_string_array", Type.STRING)
+ .array(true)
+ .build()
)
new Schema(columns.asJava)
}
@@ -268,6 +331,50 @@ trait KuduTestSuite {
row.addDecimal(13, BigDecimal.valueOf(i))
row.addDate(15, DateUtil.epochDaysToSqlDate(i))
+ val bools = Array(true, i % 2 == 0, false)
+ row.addArrayBool(16, bools, Array(true, false, true))
+ val bytes = Array(i.toByte, (i + 1).toByte, (i + 2).toByte)
+ row.addArrayInt8(17, bytes, Array(true, false, true))
+ val shorts = Array(i.toShort, (i + 1).toShort, (i + 2).toShort)
+ row.addArrayInt16(18, shorts, Array(true, false, true))
+ val ints = Array(i, i + 1, i + 2)
+ row.addArrayInt32(19, ints, Array(true, false, true))
+ val longs = Array(i.toLong, i.toLong + 10, i.toLong + 20)
+ row.addArrayInt64(20, longs, Array(true, false, true))
+ val floats = Array(i.toFloat, i.toFloat + 0.5f, i.toFloat + 1.0f)
+ row.addArrayFloat(21, floats, Array(true, false, true))
+ val doubles = Array(i.toDouble, i.toDouble + 0.5, i.toDouble + 1.0)
+ row.addArrayDouble(22, doubles, Array(true, false, true))
+ val dates = Array(
+ DateUtil.epochDaysToSqlDate(i),
+ DateUtil.epochDaysToSqlDate(i + 1),
+ DateUtil.epochDaysToSqlDate(i + 2)
+ )
+ row.addArrayDate(23, dates, Array(true, false, true))
+ val timestamps = Array(
+ new java.sql.Timestamp(ts / 1000),
+ new java.sql.Timestamp(ts / 1000 + 1000000),
+ new java.sql.Timestamp(ts / 1000 + 2000000)
+ )
+ row.addArrayTimestamp(24, timestamps, Array(true, false, true))
+ val strings = Array(s"val-$i", s"val-${i + 1}", s"val-${i + 2}")
+ row.addArrayString(25, strings, Array(true, false, true))
+ val varchars = Array(s"vchar-$i", s"vchar-${i + 1}", s"vchar-${i + 2}")
+ row.addArrayVarchar(26, varchars, Array(true, false, true))
+ val binaries = Array(
+ s"bin-$i".getBytes(UTF_8),
+ s"bin-${i + 1}".getBytes(UTF_8),
+ s"bin-${i + 2}".getBytes(UTF_8)
+ )
+ row.addArrayBinary(27, binaries, Array(true, false, true))
+ val decimals = Array(
+ BigDecimal.valueOf(i, 2),
+ BigDecimal.valueOf(i + 1, 2),
+ BigDecimal.valueOf(i + 2, 2)
+ )
+ row.addArrayDecimal(28, decimals, Array(true, false, true))
+ row.addArrayString(29, strings, Array(true, true, true))
+
// Sprinkling some nulls so that queries see them.
val s = if (i % 2 == 0) {
row.addString(2, i.toString)
@@ -317,6 +424,50 @@ trait KuduTestSuite {
row.addVarchar(14, i.toString)
row.addDate(15, DateUtil.epochDaysToSqlDate(i))
+ val bools = Array(true, i % 2 == 0, false)
+ row.addArrayBool(16, bools, Array(true, false, true))
+ val bytes = Array(i.toByte, (i + 1).toByte, (i + 2).toByte)
+ row.addArrayInt8(17, bytes, Array(true, false, true))
+ val shorts = Array(i.toShort, (i + 1).toShort, (i + 2).toShort)
+ row.addArrayInt16(18, shorts, Array(true, false, true))
+ val ints = Array(i, i + 1, i + 2)
+ row.addArrayInt32(19, ints, Array(true, false, true))
+ val longs = Array(i.toLong, i.toLong + 10, i.toLong + 20)
+ row.addArrayInt64(20, longs, Array(true, false, true))
+ val floats = Array(i.toFloat, i.toFloat + 0.5f, i.toFloat + 1.0f)
+ row.addArrayFloat(21, floats, Array(true, false, true))
+ val doubles = Array(i.toDouble, i.toDouble + 0.5, i.toDouble + 1.0)
+ row.addArrayDouble(22, doubles, Array(true, false, true))
+ val dates = Array(
+ DateUtil.epochDaysToSqlDate(i),
+ DateUtil.epochDaysToSqlDate(i + 1),
+ DateUtil.epochDaysToSqlDate(i + 2)
+ )
+ row.addArrayDate(23, dates, Array(true, false, true))
+ val timestamps = Array(
+ new java.sql.Timestamp(ts / 1000),
+ new java.sql.Timestamp(ts / 1000 + 1000000),
+ new java.sql.Timestamp(ts / 1000 + 2000000)
+ )
+ row.addArrayTimestamp(24, timestamps, Array(true, false, true))
+ val strings = Array(s"val-$i", s"val-${i + 1}", s"val-${i + 2}")
+ row.addArrayString(25, strings, Array(true, false, true))
+ val varchars = Array(s"vchar-$i", s"vchar-${i + 1}", s"vchar-${i + 2}")
+ row.addArrayVarchar(26, varchars, Array(true, false, true))
+ val binaries = Array(
+ s"bin-$i".getBytes(UTF_8),
+ s"bin-${i + 1}".getBytes(UTF_8),
+ s"bin-${i + 2}".getBytes(UTF_8)
+ )
+ row.addArrayBinary(27, binaries, Array(true, false, true))
+ val decimals = Array(
+ BigDecimal.valueOf(i, 2),
+ BigDecimal.valueOf(i + 1, 2),
+ BigDecimal.valueOf(i + 2, 2)
+ )
+ row.addArrayDecimal(28, decimals, Array(true, false, true))
+ row.addArrayString(29, strings, Array(true, true, true))
+
// Sprinkling some nulls so that queries see them.
val s = if (i % 2 == 0) {
row.addString(2, i.toString)
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkSQLTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkSQLTest.scala
index 654f167bd..0642ac26c 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkSQLTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/SparkSQLTest.scala
@@ -17,6 +17,7 @@
package org.apache.kudu.spark.kudu
+import java.nio.charset.StandardCharsets.UTF_8
import scala.collection.JavaConverters._
import scala.collection.immutable.IndexedSeq
import scala.util.control.NonFatal
@@ -531,4 +532,120 @@ class SparkSQLTest extends KuduTestSuite with Matchers {
val results = sqlContext.sql(sqlStr).collectAsList()
assert(results.size() == rowCount)
}
+
+ @Test
+ def testSelectArrayColumns(): Unit = {
+ val spark = ss
+ import spark.implicits._
+
+ val base = sqlContext
+ .sql(s"SELECT key, c19_int32_array, c25_string_array FROM $tableName ORDER BY key")
+ .collectAsList()
+
+ assert(base.size() == rowCount)
+
+ val first = base.get(0)
+ val intArray = first.getAs[Seq[Integer]]("c19_int32_array")
+ val strArray = first.getAs[Seq[String]]("c25_string_array")
+
+ // existing coverage: arrays with a null in the middle
+ assert(intArray == Seq(Integer.valueOf(0), null, Integer.valueOf(2)))
+ assert(strArray == Seq("val-0", null, "val-2"))
+ assert(intArray(1) == null)
+ assert(strArray(1) == null)
+
+ // Insert extra test rows for edge cases: empty arrays + null arrays
+ // Provide all NOT NULL fields.
+ val specialDF = Seq(
+ // Row with empty arrays
+ (
+ 999, // key
+ 999, // c1_i
+ "dummy", // c2_s
+ 1.0, // c3_double
+ 1L, // c4_long
+ true, // c5_bool
+ 1.toShort, // c6_short
+ 1.0f, // c7_float
+ "bytes-999".getBytes(UTF_8), // c8_binary
+ System.currentTimeMillis() * 1000, // c9_unixtime_micros
+ 1.toByte, // c10_byte
+ BigDecimal.valueOf(1),
+ BigDecimal.valueOf(1),
+ BigDecimal.valueOf(1), // decimals
+ "varchar-999", // c14_varchar
+ java.sql.Date.valueOf("2020-01-01"), // c15_date
+ Seq.empty[Int], // c19_int32_array
+ Seq.empty[String], // c25_string_array
+ Seq.empty[String] // c29_non_nullable_string_array
+ ),
+ // Row with null arrays (except c29_non_nullable_string_array column)
+ (
+ 1000,
+ 1000,
+ "dummy",
+ 2.0,
+ 2L,
+ false,
+ 2.toShort,
+ 2.0f,
+ "bytes-1000".getBytes(UTF_8),
+ System.currentTimeMillis() * 1000,
+ 2.toByte,
+ BigDecimal.valueOf(2),
+ BigDecimal.valueOf(2),
+ BigDecimal.valueOf(2),
+ "varchar-1000",
+ java.sql.Date.valueOf("2020-01-01"),
+ null.asInstanceOf[Seq[Int]],
+ null.asInstanceOf[Seq[String]],
+ Seq.empty[String]
+ )
+ ).toDF(
+ "key",
+ "c1_i",
+ "c2_s",
+ "c3_double",
+ "c4_long",
+ "c5_bool",
+ "c6_short",
+ "c7_float",
+ "c8_binary",
+ "c9_unixtime_micros",
+ "c10_byte",
+ "c11_decimal32",
+ "c12_decimal64",
+ "c13_decimal128",
+ "c14_varchar",
+ "c15_date",
+ "c19_int32_array",
+ "c25_string_array",
+ "c29_non_nullable_string_array"
+ )
+
+ kuduContext.insertRows(specialDF, tableName)
+
+ // Re-read and verify edge cases
+ val checkDF = spark.read
+ .options(Map("kudu.master" -> harness.getMasterAddressesAsString, "kudu.table" -> tableName))
+ .format("kudu")
+ .load()
+ .filter("key >= 999")
+ .orderBy("key")
+ .collect()
+
+ // Row 999: empty arrays
+ val rowEmpty = checkDF(0)
+ val emptyInts = rowEmpty.getAs[Seq[Integer]]("c19_int32_array")
+ val emptyStrs = rowEmpty.getAs[Seq[String]]("c25_string_array")
+ assert(emptyInts != null && emptyInts.isEmpty, "Empty int array should return Seq()")
+ assert(emptyStrs != null && emptyStrs.isEmpty, "Empty string array should return Seq()")
+
+ // Row 1000: null arrays
+ val rowNull = checkDF(1)
+ val nullInts = rowNull.getAs[Seq[Integer]]("c19_int32_array")
+ val nullStrs = rowNull.getAs[Seq[String]]("c25_string_array")
+ assert(nullInts == null, "Null int array cell should map to null")
+ assert(nullStrs == null, "Null string array cell should map to null")
+ }
}