Bruce Robbins created SPARK-52738:
-------------------------------------
Summary: Support aggregating TIME type with a UDAF when the
underlying buffer is an UnsafeRow
Key: SPARK-52738
URL: https://issues.apache.org/jira/browse/SPARK-52738
Project: Spark
Issue Type: Sub-task
Components: SQL
Affects Versions: 4.1.0
Reporter: Bruce Robbins
Spark gets an error while aggregating a TIME type with a UDAF when the
underlying aggregation buffer is an unsafe row (i.e., when all fields in the
schema are considered mutable by {{UnsafeRow}}).
Assume this code:
{noformat}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
/**
 * UDAF used to reproduce the issue: its input schema, aggregation buffer schema and
 * result type are all the same `schema`.
 *
 * The buffer starts with every field set to null. A row (or a partial buffer) is copied
 * into the buffer, field by field, only when its first field is non-null and equals 50.
 *
 * @param schema struct type shared by the input, the buffer and the result
 */
class ScalaAggregateFunction(schema: StructType)
  extends UserDefinedAggregateFunction {

  def inputSchema: StructType = schema

  def bufferSchema: StructType = schema

  def dataType: DataType = schema

  def deterministic: Boolean = true

  /** Sets every field of the aggregation buffer to null. */
  def initialize(buffer: MutableAggregationBuffer): Unit =
    for (idx <- schema.indices) buffer.update(idx, null)

  /** Copies `input` into `buffer` when the first field of `input` is 50. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    copyWhenFirstFieldIs50(buffer, input)

  /** Copies `buffer2` into `buffer1` when the first field of `buffer2` is 50. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    copyWhenFirstFieldIs50(buffer1, buffer2)

  /** Returns the buffer contents as a new `Row`. */
  def evaluate(buffer: Row): Any = Row.fromSeq(buffer.toSeq)

  // Shared by update and merge: overwrite every field of `target` with the matching
  // field of `source`, but only if the first field of `source` is non-null and equals 50.
  private def copyWhenFirstFieldIs50(target: MutableAggregationBuffer, source: Row): Unit = {
    val firstFieldIs50 = !source.isNullAt(0) && source.getInt(0) == 50
    if (firstFieldIs50) {
      for (idx <- schema.indices) target.update(idx, source.get(idx))
    }
  }
}
import scala.util.Random
import java.time.LocalTime
// Reproduction driver (meant to be pasted into spark-shell; assumes `spark` and `col`
// are already in scope there -- confirm when running elsewhere).

// 50 rows of (id, col1, col2, col3): id runs from 1 to 50, and col3 is the local time
// 23:33:33.123 moved back by (x % 1300 + 1) minutes.
val data = Seq.tabulate(50) { x =>
  Row(x + 1, (x + 2).toDouble, (x + 2).toLong,
    LocalTime.parse("23:33:33.123").minusMinutes(x % 1300 + 1))
}
val schema = StructType.fromDDL("id int, col1 double, col2 bigint, col3 time")

// Single partition.
val rdd = spark.sparkContext.parallelize(data, 1)
val df = spark.createDataFrame(rdd, schema)

// Every column is fed to the UDAF, so its buffer schema equals the DataFrame schema.
val udaf = new ScalaAggregateFunction(df.schema)

// Convert to an immutable Seq before expanding with `: _*`; passing an Array directly to
// a varargs method is deprecated since Scala 2.13.
val allColumns = df.schema.fields.toIndexedSeq.map(f => col(f.name))

// Global aggregation (no grouping keys); this is the call that fails in the report.
df.groupBy().agg(udaf(allColumns: _*)).show(false)
{noformat}
It gets this error:
{noformat}
warning: 1 deprecation (since 2.13.0); for details, enable `:setting
-deprecation` or `:replay -deprecation`
Exception in task 0.0 in stage 0.0 (TID 0)
org.apache.spark.SparkUnsupportedOperationException:
[UNSUPPORTED_CALL.WITHOUT_SUGGESTION] Cannot call the method "update" of the
class "org.apache.spark.sql.catalyst.expressions.UnsafeRow". SQLSTATE: 0A000
{noformat}
--
This message was sent by Atlassian Jira
(v8.20.10#820010)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]