sunchao commented on a change in pull request #32764:
URL: https://github.com/apache/spark/pull/32764#discussion_r644507916



##########
File path: sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
##########
@@ -117,6 +120,8 @@
  *   <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
  *   <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
  *   <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
+ *   <li>{@link org.apache.spark.sql.types.CalendarIntervalType}:

Review comment:
       This is not related; let me know if I should move it to a separate PR.

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager: CatalogManager)
                        unbound, arguments, unsupported)
                  }
 
+                  if (bound.inputTypes().length != arguments.length) {
+                    throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+                      bound, arguments)
+                  }
+
+                  val castedArguments = arguments.zip(bound.inputTypes()).map { case (arg, ty) =>
+                    if (arg.dataType != ty) {
+                      if (Cast.canCast(arg.dataType, ty)) {
+                        Cast(arg, ty)
+                      } else {
+                        throw QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+                      }
+                    } else {
+                      arg
+                    }
+                  }

Review comment:
       Yes, `bind` checks input types and potentially allows multiple combinations of them. However, when evaluating the UDF, especially via the magic method, Spark only accepts a single set of input parameter types, so we need to insert casts when necessary.
   
   For instance, a UDF can accept both `int` and `decimal` as input types in `bind` while implementing the magic method with `decimal` parameters. Spark should then cast `int` arguments to `decimal` when necessary.
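   To make the int/decimal case concrete, here is a rough sketch (illustrative only, not code from this PR; all names are made up) of a UDF whose `bind` accepts both types while the bound function only exposes a decimal magic method:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

object UnboundDecimalIdentity extends UnboundFunction {
  override def name(): String = "decimal_identity"
  override def description(): String = "decimal_identity(int | decimal) -> decimal"

  override def bind(inputType: StructType): BoundFunction = {
    if (inputType.fields.length != 1) {
      throw new UnsupportedOperationException("Expected exactly one argument")
    }
    // Accept either int or decimal at bind time.
    inputType.fields(0).dataType match {
      case _: IntegerType | _: DecimalType => DecimalIdentity
      case other => throw new UnsupportedOperationException(s"Unsupported input type: $other")
    }
  }
}

object DecimalIdentity extends ScalarFunction[Decimal] {
  override def name(): String = "decimal_identity"
  // A single declared input type: int arguments must be cast to decimal(15, 5).
  override def inputTypes(): Array[DataType] = Array(DecimalType(15, 5))
  override def resultType(): DataType = DecimalType(15, 5)

  // Magic method: only ever invoked with Decimal, since int arguments are
  // already cast during analysis.
  def invoke(value: Decimal): Decimal = value

  // Non-magic fallback path.
  override def produceResult(input: InternalRow): Decimal = input.getDecimal(0, 15, 5)
}
```

   With something like this, an int argument resolves to `Cast(arg, DecimalType(15, 5))` before the function is invoked.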

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
     override def produceResult(state: (Long, Long)): Long = state._1 / state._2
   }
 
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls

Review comment:
       My bad - let me change the name to something more meaningful, like `decimal_avg`.

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -466,6 +526,60 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
     override def produceResult(state: (Long, Long)): Long = state._1 / state._2
   }
 
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "favg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
+
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using decimal division, ignoring nulls
+        |  iavg(integral) -> decimal
+        |  iavg(float) -> decimal
+        |  iavg(decimal) -> decimal""".stripMargin
+  }
+
+  object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+    val PRECISION: Int = 15
+    val SCALE: Int = 5

Review comment:
       Sorry, this is a little confusing. I was using this to test a negative case: passing `long`, which is `decimal(20, 0)`, to this `decimal(15, 5)`, which won't work without losing precision. But I no longer need this now; will remove.
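   A small illustration of that precision gap (using Spark's standard decimal type bounds; illustrative, not test code):

```scala
import org.apache.spark.sql.types._

// A long fits in decimal(20, 0); decimal(15, 5) keeps only 10 integer digits,
// so large long values cannot be represented in it without losing precision.
val longAsDecimal = DecimalType.LongDecimal      // DecimalType(20, 0)
val target = DecimalType(15, 5)
val integerDigits = target.precision - target.scale
println(s"long needs ${longAsDecimal.precision} integer digits, target keeps only $integerDigits")
```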

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager: CatalogManager)
                        unbound, arguments, unsupported)
                  }
 
+                  if (bound.inputTypes().length != arguments.length) {
+                    throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+                      bound, arguments)
+                  }
+
+                  val castedArguments = arguments.zip(bound.inputTypes()).map { case (arg, ty) =>
+                    if (arg.dataType != ty) {
+                      if (Cast.canCast(arg.dataType, ty)) {
+                        Cast(arg, ty)
+                      } else {
+                        throw QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+                      }
+                    } else {
+                      arg
+                    }
+                  }

Review comment:
       > Since we have two places of type coercion, is there any possibility of bug where we do differently in some cases at both places?
   
   Hmm @dongjoon-hyun what are the two places of type coercion? Do you mean a potential inconsistency between `bind` and `inputTypes`?
   
   > Thanks for explaining it. So that is said, bind can return a implementation with magic method, which takes decimal input, when Spark binds it with IntegerType input?
   
   Yes, correct. Spark will insert a cast `int -> decimal`.
   
   Also, from the Javadoc of `inputTypes`: "If the types returned differ from the types passed to `bind(StructType)`, Spark will cast input values to the required data types. This allows implementations to delegate input value casting to Spark."
   
   > Using above example, when Spark binds it with IntegerType, the UDF must know Spark can cast int to decimal, so it can return an implementation with magic method taking decimal input?
   
   Yes, the implementor of the UDF should know whether such a cast is valid. Spark also checks this during analysis and throws an error if not.
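   A rough sketch of that analysis-time check (illustrative only, not the exact Analyzer code path):

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// The argument type must be castable to the type declared by inputTypes(),
// otherwise analysis fails.
val argType: DataType = IntegerType              // type of the actual argument
val declaredType: DataType = DecimalType(15, 5)  // type from bound.inputTypes()

if (argType == declaredType) {
  println("argument used as-is")
} else if (Cast.canCast(argType, declaredType)) {
  println(s"analyzer inserts Cast($argType -> $declaredType)")
} else {
  println("analysis fails with a cast error")
}
```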
   

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2169,12 +2169,29 @@ class Analyzer(override val catalogManager: CatalogManager)
                        unbound, arguments, unsupported)
                  }
 
+                  if (bound.inputTypes().length != arguments.length) {
+                    throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
+                      bound, arguments)
+                  }
+
+                  val castedArguments = arguments.zip(bound.inputTypes()).map { case (arg, ty) =>
+                    if (arg.dataType != ty) {
+                      if (Cast.canCast(arg.dataType, ty)) {
+                        Cast(arg, ty)
+                      } else {
+                        throw QueryCompilationErrors.v2FunctionCastError(bound, arg, ty)
+                      }
+                    } else {
+                      arg
+                    }
+                  }

Review comment:
       Yes, `bind` checks input types and allows multiple combinations of them. However, when evaluating the UDF, especially via the magic method, Spark only accepts a single set of input parameter types, so we need to insert casts when necessary.
   
   For instance, a UDF can accept both `int` and `decimal` as input types in `bind` while implementing the magic method with `decimal` parameters. Spark should then cast `int` arguments to `decimal` when necessary.
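   At the SQL level the cast stays transparent to the caller; a hypothetical usage (catalog, namespace, and function names are illustrative):

```scala
// Assumes a UDF like the decimal sketch earlier in this thread is registered
// under testcat.ns; the int literal is cast to decimal before the magic
// method runs.
spark.sql("SELECT testcat.ns.decimal_identity(42)").show()
```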

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
##########
@@ -21,15 +21,15 @@ import java.util
 import java.util.Collections
 
 import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
-import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
 import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
 import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
-import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}

Review comment:
       Do you mean to split it into two lines?



