Bruce Robbins created SPARK-38146: ------------------------------------- Summary: UDAF fails with unsafe rows containing a TIMESTAMP_NTZ column Key: SPARK-38146 URL: https://issues.apache.org/jira/browse/SPARK-38146 Project: Spark Issue Type: Bug Components: SQL Affects Versions: 3.3.0 Reporter: Bruce Robbins
When a UDAF is applied to unsafe rows containing a TIMESTAMP_NTZ column, Spark throws the following error: {noformat} 22/02/08 18:05:12 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.UnsupportedOperationException: null at org.apache.spark.sql.catalyst.expressions.UnsafeRow.update(UnsafeRow.java:218) ~[spark-catalyst_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT] at org.apache.spark.sql.execution.aggregate.BufferSetterGetterUtils.$anonfun$createSetters$15(udaf.scala:217) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT] at org.apache.spark.sql.execution.aggregate.BufferSetterGetterUtils.$anonfun$createSetters$15$adapted(udaf.scala:215) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT] at org.apache.spark.sql.execution.aggregate.MutableAggregationBufferImpl.update(udaf.scala:272) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT] at $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$ScalaAggregateFunction.$anonfun$update$1(<console>:46) ~[scala-library.jar:?] at scala.collection.immutable.Range.foreach$mVc$sp(Range.scala:158) ~[scala-library.jar:?] at $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$ScalaAggregateFunction.update(<console>:45) ~[scala-library.jar:?] at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:458) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT] at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1.$anonfun$applyOrElse$2(AggregationIterator.scala:197) ~[spark-sql_2.12-3.3.0-SNAPSHO {noformat} This happens because {{BufferSetterGetterUtils#createSetters}} has no case for {{TimestampNTZType}}. As a result, the setter function it generates for that column calls {{UnsafeRow.update}}, which throws an {{UnsupportedOperationException}}.
The following reproduction is adapted from {{AggregationQuerySuite}}: {noformat} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { def inputSchema: StructType = schema def bufferSchema: StructType = schema def dataType: DataType = schema def deterministic: Boolean = true def initialize(buffer: MutableAggregationBuffer): Unit = { (0 until schema.length).foreach { i => buffer.update(i, null) } } def update(buffer: MutableAggregationBuffer, input: Row): Unit = { if (!input.isNullAt(0) && input.getInt(0) == 50) { (0 until schema.length).foreach { i => buffer.update(i, input.get(i)) } } } def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) { (0 until schema.length).foreach { i => buffer1.update(i, buffer2.get(i)) } } } def evaluate(buffer: Row): Any = { Row.fromSeq(buffer.toSeq) } } import scala.util.Random import java.time.LocalDateTime val r = new Random(65676563L) val data = Seq.tabulate(50) { x => Row((x + 1).toInt, (x + 2).toDouble, (x + 2).toLong, LocalDateTime.parse("2100-01-01T01:33:33.123").minusDays(x + 1)) } val schema = StructType.fromDDL("id int, col1 double, col2 bigint, col3 timestamp_ntz") val rdd = spark.sparkContext.parallelize(data, 1) val df = spark.createDataFrame(rdd, schema) val udaf = new ScalaAggregateFunction(df.schema) val allColumns = df.schema.fields.map(f => col(f.name)) df.groupBy().agg(udaf(allColumns: _*)).show(false) {noformat} -- This message was sent by Atlassian Jira (v8.20.1#820001) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org