dongjoon-hyun commented on a change in pull request #32764:
URL: https://github.com/apache/spark/pull/32764#discussion_r645136124
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
##########
@@ -117,6 +120,8 @@
* <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
+ * <li>{@link org.apache.spark.sql.types.CalendarIntervalType}:
Review comment:
I'm fine. :)
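For anyone reading the mapping list above, a minimal sketch (names and values are illustrative, not part of this PR) of a `ScalarFunction` that works directly on those physical representations, i.e. `DateType` as `int` and `TimestampType` as `long`:
```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, DateType, TimestampType}

// Hypothetical example: DateType arrives as an int (days since the epoch)
// and TimestampType is produced as a long (microseconds since the epoch).
object DateToMicros extends ScalarFunction[Long] {
  override def inputTypes(): Array[DataType] = Array(DateType)
  override def resultType(): DataType = TimestampType
  override def name(): String = "date_to_micros"

  // InternalRow-based entry point; a magic `invoke(days: Int): Long` method
  // could be added as well to avoid the row accessor.
  override def produceResult(input: InternalRow): Long =
    input.getInt(0).toLong * 24L * 60L * 60L * 1000L * 1000L
}
```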
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -21,15 +21,15 @@ import java.util
import java.util.Collections
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
-import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
-import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}
Review comment:
`{AggregateFunction, _}`? Do we still need to list `AggregateFunction` explicitly when the wildcard import already covers it?
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -312,6 +344,27 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
}
}
+ test("SPARK-35390: aggregate function w/ type coercion") {
+ import testImplicits._
+
+ withTable("t1", "t2") {
+ addFunction(Identifier.of(Array("ns"), "avg"), UnboundDecimalAverage)
+
+ (1 to 100).toDF().write.saveAsTable("testcat.ns.t1")
+ checkAnswer(sql(s"SELECT testcat.ns.avg(value) from testcat.ns.t1"),
Review comment:
nit. `s"` -> `"`.
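To spell the nit out: the `s` interpolator is only needed when the string actually contains a `$` substitution; a tiny illustration (the table name is just an example):
```scala
val table = "testcat.ns.t1"
val interpolated = s"SELECT testcat.ns.avg(value) FROM $table" // needs s"..."
val plain = "SELECT testcat.ns.avg(value) FROM testcat.ns.t1"  // plain "..." is enough
```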
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -312,6 +344,27 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
}
}
+ test("SPARK-35390: aggregate function w/ type coercion") {
+ import testImplicits._
+
+ withTable("t1", "t2") {
+ addFunction(Identifier.of(Array("ns"), "avg"), UnboundDecimalAverage)
+
+ (1 to 100).toDF().write.saveAsTable("testcat.ns.t1")
+ checkAnswer(sql(s"SELECT testcat.ns.avg(value) from testcat.ns.t1"),
Review comment:
And, ditto for lines 358 and 362.
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
override def produceResult(state: (Long, Long)): Long = state._1 / state._2
}
+ object UnboundDecimalAverage extends UnboundFunction {
+ override def name(): String = "favg"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length > 1) {
+ throw new UnsupportedOperationException("Too many arguments")
+ }
+
+ // put interval type here for testing purpose
+ inputType.fields(0).dataType match {
+ case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+ }
+ }
+
+ override def description(): String =
+ """iavg: produces an average using decimal division, ignoring nulls
Review comment:
Just a question. What is the meaning of `i` here? Previously, I thought it stood for `Integer`.
BTW, `davg` may look a little confusing because it can be read as DecimalAverage or DoubleAverage.
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
override def produceResult(state: (Long, Long)): Long = state._1 / state._2
}
+ object UnboundDecimalAverage extends UnboundFunction {
+ override def name(): String = "favg"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length > 1) {
+ throw new UnsupportedOperationException("Too many arguments")
+ }
+
+ // put interval type here for testing purpose
+ inputType.fields(0).dataType match {
+ case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+ }
+ }
+
+ override def description(): String =
+ """iavg: produces an average using decimal division, ignoring nulls
Review comment:
Just a question. What is the meaning of `i` here? Previously, I thought it stood for `Integer`. At line 556, `davg` is used though.
```
override def name(): String = "davg"
```
BTW, `davg` may look a little confusing because it can be read as
DecimalAverage or DoubleAverage.
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
override def produceResult(state: (Long, Long)): Long = state._1 / state._2
}
+ object UnboundDecimalAverage extends UnboundFunction {
+ override def name(): String = "favg"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length > 1) {
+ throw new UnsupportedOperationException("Too many arguments")
+ }
+
+ // put interval type here for testing purpose
+ inputType.fields(0).dataType match {
+ case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+ }
+ }
+
+ override def description(): String =
+ """iavg: produces an average using decimal division, ignoring nulls
+ | iavg(integral) -> decimal
+ | iavg(float) -> decimal
+ | iavg(decimal) -> decimal""".stripMargin
+ }
+
+ object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+ val PRECISION: Int = 15
+ val SCALE: Int = 5
Review comment:
Sorry, but may I ask where these values (precision 15, scale 5) came from? I might be missing the context, but they look like non-conventional values in the Apache Spark codebase. Were they intentionally chosen for some purpose?
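For reference, the defaults I usually see in the codebase are the `DecimalType` constants below (just my recollection, please correct me if the test intentionally needs different values):
```scala
import org.apache.spark.sql.types.DecimalType

DecimalType.MAX_PRECISION  // 38
DecimalType.SYSTEM_DEFAULT // DecimalType(38, 18)
DecimalType.USER_DEFAULT   // DecimalType(10, 0)
```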
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
override def produceResult(state: (Long, Long)): Long = state._1 / state._2
}
+ object UnboundDecimalAverage extends UnboundFunction {
+ override def name(): String = "favg"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length > 1) {
+ throw new UnsupportedOperationException("Too many arguments")
+ }
+
+ // put interval type here for testing purpose
+ inputType.fields(0).dataType match {
+ case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+ }
+ }
+
+ override def description(): String =
+ """iavg: produces an average using decimal division, ignoring nulls
+ | iavg(integral) -> decimal
+ | iavg(float) -> decimal
+ | iavg(decimal) -> decimal""".stripMargin
+ }
+
+ object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+ val PRECISION: Int = 15
+ val SCALE: Int = 5
Review comment:
Sorry, but may I ask where these values (precision 15, scale 5) came from? I might be missing the context, but they look like non-conventional values in the Apache Spark codebase. Were they intentionally chosen for some purpose?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager: CatalogManager)
unbound, arguments, unsupported)
}
+ if (bound.inputTypes().length != arguments.length) {
+ throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+ bound, arguments)
+ }
+
+ val castedArguments = arguments.zip(bound.inputTypes()).map { case (arg, ty) =>
+ if (arg.dataType != ty) {
+ if (Cast.canCast(arg.dataType, ty)) {
+ Cast(arg, ty)
+ } else {
+ throw QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+ }
+ } else {
+ arg
+ }
+ }
Review comment:
Since we now have type coercion in two places, is there any possibility of a bug where the two places behave differently in some cases?
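A hypothetical illustration of the concern (not from the PR): the `Cast.canCast` check used here is quite permissive, e.g. it accepts string-to-int, so if the other coercion site applies stricter implicit-cast rules (my assumption), the two places could accept different argument types:
```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.{IntegerType, StringType}

// Accepted by the explicit-cast check used in the snippet above; an
// implicit-coercion path may or may not allow the same conversion.
assert(Cast.canCast(StringType, IntegerType))
```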