sunchao commented on a change in pull request #32764:
URL: https://github.com/apache/spark/pull/32764#discussion_r644507916
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
##########
@@ -117,6 +120,8 @@
* <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
+ * <li>{@link org.apache.spark.sql.types.CalendarIntervalType}:
Review comment:
This is not related - let me know if I should move it to a separate PR.
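For context, a small illustrative sketch of how the mappings listed above would show up in a magic method signature; the function name and logic are made up for this example, not part of the PR:

```scala
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types._

// Hypothetical function: per the mapping above, a DateType argument arrives in
// the magic method as an Int (days since epoch) and a TimestampType argument as
// a Long (microseconds since epoch).
object DaysBetween extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(DateType, TimestampType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "days_between"

  // magic method using the unboxed representations
  def invoke(date: Int, ts: Long): Int = (ts / 86400000000L - date).toInt
}
```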
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager: CatalogManager)
             unbound, arguments, unsupported)
         }
+        if (bound.inputTypes().length != arguments.length) {
+          throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+            bound, arguments)
+        }
+
+        val castedArguments = arguments.zip(bound.inputTypes()).map { case (arg, ty) =>
+          if (arg.dataType != ty) {
+            if (Cast.canCast(arg.dataType, ty)) {
+              Cast(arg, ty)
+            } else {
+              throw QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+            }
+          } else {
+            arg
+          }
+        }
Review comment:
Yes, `bind` checks the input types and potentially allows multiple combinations of them. However, when evaluating the UDF, especially via the magic method, Spark only accepts a single set of input parameter types, so we need to insert casts where necessary.

For instance, a UDF can accept both `int` and `decimal` as input types in `bind`, but implement the magic method with `decimal` parameters. Spark should then cast `int` arguments to `decimal` when necessary.
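For illustration, a rough sketch of what such a UDF could look like (the names and types below are made up for this example, not taken from the PR):

```scala
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

// Hypothetical UDF: `bind` accepts either int or decimal, but the magic method
// is only implemented for Decimal, so `inputTypes` advertises decimal and Spark
// is expected to cast int arguments before invoking it.
object UnboundDecimalIdentity extends UnboundFunction {
  override def name(): String = "decimal_identity"
  override def description(): String = "decimal_identity(int | decimal) -> decimal"

  override def bind(inputType: StructType): BoundFunction =
    inputType.fields(0).dataType match {
      case IntegerType | _: DecimalType => DecimalIdentity
      case dt => throw new UnsupportedOperationException(s"Unsupported input type: $dt")
    }
}

object DecimalIdentity extends ScalarFunction[Decimal] {
  override def inputTypes(): Array[DataType] = Array(DecimalType(38, 18))
  override def resultType(): DataType = DecimalType(38, 18)
  override def name(): String = "decimal_identity"

  // magic method: only handles Decimal; an int argument relies on the analyzer
  // inserting Cast(arg, DecimalType(38, 18)) first
  def invoke(value: Decimal): Decimal = value
}
```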
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
   override def produceResult(state: (Long, Long)): Long = state._1 / state._2
 }
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls
Review comment:
My bad - let me change the name to something more meaningful, like `decimal_avg`.
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
   override def produceResult(state: (Long, Long)): Long = state._1 / state._2
 }
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls
+        | iavg(integral) -> decimal
+        | iavg(float) -> decimal
+        | iavg(decimal) -> decimal""".stripMargin
+  }
+
+  object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+    val PRECISION: Int = 15
+    val SCALE: Int = 5
+ }
+
+ object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+ val PRECISION: Int = 15
+ val SCALE: Int = 5
Review comment:
Sorry, this is a little confusing - I was using this to test a negative case: passing a `long`, which is represented as `decimal(20, 0)`, into this `decimal(15, 5)`, which won't work without losing precision. But I no longer need this now, so I'll remove it.
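For context, a quick sketch of why that cast is lossy (just an illustration, not test code from this PR): a `long` needs up to 19-20 integral digits, while `decimal(15, 5)` only leaves room for 10.

```scala
import org.apache.spark.sql.types.Decimal

// Long.MaxValue has 19 digits, so it fits decimal(20, 0)...
val d = Decimal(Long.MaxValue)
// ...but cannot be narrowed to decimal(15, 5), which holds at most
// 10 digits before the decimal point; changePrecision reports the overflow.
val fits = d.changePrecision(15, 5)   // false
```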
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager: CatalogManager)
             unbound, arguments, unsupported)
         }
+        if (bound.inputTypes().length != arguments.length) {
+          throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+            bound, arguments)
+        }
+
+        val castedArguments = arguments.zip(bound.inputTypes()).map { case (arg, ty) =>
+          if (arg.dataType != ty) {
+            if (Cast.canCast(arg.dataType, ty)) {
+              Cast(arg, ty)
+            } else {
+              throw QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+            }
+          } else {
+            arg
+          }
+        }
Review comment:
> Since we have two places of type coercion, is there any possibility of bug where we do differently in some cases at both places?

Hmm @dongjoon-hyun, what are the two places of type coercion? Do you mean a potential inconsistency between `bind` and `inputTypes`?

> Thanks for explaining it. So that is said, bind can return a implementation with magic method, which takes decimal input, when Spark binds it with IntegerType input?

Yes, correct. Spark will insert a cast `int -> decimal`.

Also, from the Java doc of `inputTypes`: "If the types returned differ from the types passed to `bind(StructType)`, Spark will cast input values to the required data types. This allows implementations to delegate input value casting to Spark."

> Using above example, when Spark binds it with IntegerType, the UDF must know Spark can cast int to decimal, so it can return an implementation with magic method taking decimal input?

Yes, the implementor of the UDF should know whether such a cast is valid. Spark also checks this during analysis and throws an error if not.
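To make the flow concrete, a hypothetical end-to-end sketch (catalog, namespace and function names are placeholders, not from this PR):

```scala
// Assuming a function like the decimal example above is registered as
// testcat.ns.decimal_identity with inputTypes() = [decimal(38, 18)]:
spark.sql("SELECT testcat.ns.decimal_identity(1)")
// The analyzer sees arg.dataType = IntegerType vs. the declared decimal(38, 18);
// Cast.canCast(IntegerType, DecimalType(38, 18)) holds, so the argument is
// rewritten to Cast(1, DecimalType(38, 18)) before the magic method is invoked.
// If canCast were false, analysis would fail with a v2 function cast error instead.
```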
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager:
CatalogManager)
unbound, arguments, unsupported)
}
+ if (bound.inputTypes().length != arguments.length) {
+ throw
QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+ bound, arguments)
+ }
+
+ val castedArguments = arguments.zip(bound.inputTypes()).map
{ case (arg, ty) =>
+ if (arg.dataType != ty) {
+ if (Cast.canCast(arg.dataType, ty)) {
+ Cast(arg, ty)
+ } else {
+ throw
QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+ }
+ } else {
+ arg
+ }
+ }
Review comment:
Yes, `bind` checks the input types and allows multiple combinations of them. However, when evaluating the UDF, especially via the magic method, Spark only accepts a single set of input parameter types, so we need to insert casts where necessary.

For instance, a UDF can accept both `int` and `decimal` as input types in `bind`, but implement the magic method with `decimal` parameters. Spark should then cast `int` arguments to `decimal` when necessary.
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -21,15 +21,15 @@ import java.util
import java.util.Collections
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
-import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
-import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}
Review comment:
Do you mean splitting it into two lines?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]