dongjoon-hyun commented on a change in pull request #32764:
URL: https://github.com/apache/spark/pull/32764#discussion_r645136124



##########
File path: 
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
##########
@@ -117,6 +120,8 @@
  *   <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
  *   <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
  *   <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
+ *   <li>{@link org.apache.spark.sql.types.CalendarIntervalType}:

Review comment:
       I'm fine. :)

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -21,15 +21,15 @@ import java.util
 import java.util.Collections
 
 import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, 
JavaLongAdd, JavaStrLen}
-import 
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
+import 
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault,
 JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
 import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import 
org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, 
NO_CODEGEN}
 import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, 
Identifier, InMemoryCatalog, SupportsNamespaces}
-import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}

Review comment:
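       ? `AggregateFunction` is already covered by the `_` wildcard in this import, so listing it explicitly looks redundant.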

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -312,6 +344,27 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     }
   }
 
+  test("SPARK-35390: aggregate function w/ type coercion") {
+    import testImplicits._
+
+    withTable("t1", "t2") {
+      addFunction(Identifier.of(Array("ns"), "avg"), UnboundDecimalAverage)
+
+      (1 to 100).toDF().write.saveAsTable("testcat.ns.t1")
+      checkAnswer(sql(s"SELECT testcat.ns.avg(value) from testcat.ns.t1"),

Review comment:
       nit: `s"` -> `"`, since this string has no interpolated values.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -312,6 +344,27 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     }
   }
 
+  test("SPARK-35390: aggregate function w/ type coercion") {
+    import testImplicits._
+
+    withTable("t1", "t2") {
+      addFunction(Identifier.of(Array("ns"), "avg"), UnboundDecimalAverage)
+
+      (1 to 100).toDF().write.saveAsTable("testcat.ns.t1")
+      checkAnswer(sql(s"SELECT testcat.ns.avg(value) from testcat.ns.t1"),

Review comment:
       Ditto for lines 358 and 362.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     override def produceResult(state: (Long, Long)): Long = state._1 / state._2
   }
 
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: 
$dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls

Review comment:
       Just a question: what does the `i` mean here? I previously thought 
it stood for `Integer`.
   BTW, `davg` may be a little confusing because it can be read as either 
DecimalAverage or DoubleAverage.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     override def produceResult(state: (Long, Long)): Long = state._1 / state._2
   }
 
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: 
$dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls

Review comment:
       Just a question: what does the `i` mean here? I previously thought 
it stood for `Integer`. However, `davg` is used at line 556:
   ```
   override def name(): String = "davg"
   ```
   BTW, `davg` may be a little confusing because it can be read as either 
DecimalAverage or DoubleAverage.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     override def produceResult(state: (Long, Long)): Long = state._1 / state._2
   }
 
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: 
$dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls
+        |  iavg(integral) -> decimal
+        |  iavg(float) -> decimal
+        |  iavg(decimal) -> decimal""".stripMargin
+  }
+
+  object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+    val PRECISION: Int = 15
+    val SCALE: Int = 5

Review comment:
       Sorry, but may I ask where these values (precision 15, scale 5) came from? I 
might have missed the context, but these look like unconventional values in the 
Apache Spark codebase. Were they chosen intentionally for some purpose?

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     override def produceResult(state: (Long, Long)): Long = state._1 / state._2
   }
 
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: 
$dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls
+        |  iavg(integral) -> decimal
+        |  iavg(float) -> decimal
+        |  iavg(decimal) -> decimal""".stripMargin
+  }
+
+  object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+    val PRECISION: Int = 15
+    val SCALE: Int = 5

Review comment:
       Sorry, but may I ask where these values (precision 15, scale 5) came from? I 
might have missed the context, but these look like unconventional values in the 
Apache Spark codebase. Were they chosen intentionally for some purpose?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager: 
CatalogManager)
                         unbound, arguments, unsupported)
                   }
 
+                  if (bound.inputTypes().length != arguments.length) {
+                    throw 
QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+                      bound, arguments)
+                  }
+
+                  val castedArguments = arguments.zip(bound.inputTypes()).map 
{ case (arg, ty) =>
+                    if (arg.dataType != ty) {
+                      if (Cast.canCast(arg.dataType, ty)) {
+                        Cast(arg, ty)
+                      } else {
+                        throw 
QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+                      }
+                    } else {
+                      arg
+                    }
+                  }

Review comment:
       Since type coercion now happens in two places, is there any possibility 
of a bug where the two places coerce differently in some cases?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to