cloud-fan commented on a change in pull request #27937: [SPARK-30127][SQL] Support case class parameter for typed Scala UDF
URL: https://github.com/apache/spark/pull/27937#discussion_r395489257
 
 

 ##########
 File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
 ##########
 @@ -48,25 +46,87 @@ case class ScalaUDF(
     function: AnyRef,
     dataType: DataType,
     children: Seq[Expression],
-    inputPrimitives: Seq[Boolean],
-    inputTypes: Seq[AbstractDataType] = Nil,
+    inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
     udfDeterministic: Boolean = true)
   extends Expression with NonSQLExpression with UserDefinedExpression {
 
  override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 
+  private lazy val resolvedEnc = mutable.HashMap[Int, ExpressionEncoder[_]]()
+
  override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
 
+  /**
+   * The analyzer should be aware of Scala primitive types so as to make the
+   * UDF return null if there is any null input value of these types. On the
+   * other hand, Java UDFs can only have boxed types, so for Java UDFs this
+   * sequence is always all false.
+   */
+  def inputPrimitives: Seq[Boolean] = {
+    inputEncoders.map { encoderOpt =>
+      // It's possible that some of the inputs don't have a specific encoder (e.g. `Any`).
+      if (encoderOpt.isDefined) {
+        val encoder = encoderOpt.get
+        if (encoder.isSerializedAsStruct) {
+          // struct type is not primitive
+          false
+        } else {
+          // `nullable` is false iff the type is primitive
+          !encoder.schema.head.nullable
+        }
+      } else {
+        // Any type is not primitive
+        false
+      }
+    }
+  }
+
+  /**
+   * The expected input types of this UDF, used to perform type coercion. If we
+   * do not want to perform coercion, simply use "Nil". Note that it would've
+   * been better to use Option of Seq[DataType] so we can use "None" as the case
+   * for no type coercion. However, that would require more refactoring of the
+   * codebase.
+   */
+  def inputTypes: Seq[AbstractDataType] = {
+    inputEncoders.map { encoderOpt =>
+      if (encoderOpt.isDefined) {
+        val encoder = encoderOpt.get
+        if (encoder.isSerializedAsStruct) {
+          encoder.schema
+        } else {
+          encoder.schema.head.dataType
+        }
+      } else {
+        AnyDataType
+      }
+    }
+  }
+
+  private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = {
+    if (inputEncoders.isEmpty) {
+      // for untyped Scala UDF
+      CatalystTypeConverters.createToScalaConverter(dataType)
+    } else {
+      val encoder = inputEncoders(i)
+      if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
+        val enc = resolvedEnc.getOrElseUpdate(i, encoder.get.resolveAndBind())
 
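Aside (not part of the patch): a minimal sketch of what `inputPrimitives` and `inputTypes` above would yield for a mix of input types, assuming a hypothetical case class `Point` and the `TypeTag`-based `ExpressionEncoder()` constructor:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical case class standing in for a typed UDF input.
case class Point(x: Int, y: Int)

// Encoders for a hypothetical UDF over (Int, String, Point, Any):
val encoders: Seq[Option[ExpressionEncoder[_]]] = Seq(
  Some(ExpressionEncoder[Int]()),    // primitive: its schema field is non-nullable
  Some(ExpressionEncoder[String]()), // reference type: its schema field is nullable
  Some(ExpressionEncoder[Point]()),  // case class: serialized as struct
  None                               // no specific encoder available, e.g. `Any`
)

// With these encoders:
//   inputPrimitives -> Seq(true, false, false, false)
//   inputTypes      -> Seq(IntegerType, StringType,
//                          StructType(StructField("x", IntegerType, false) ::
//                                     StructField("y", IntegerType, false) :: Nil),
//                          AnyDataType)
```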
 Review comment:
   why do we need `resolvedEnc`? I think we can simply write
   ```scala
   val enc = encoder.get.resolveAndBind()
   (row: Any) => enc.fromRow(row.asInstanceOf[InternalRow])
   ```
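
For context on the suggestion, a side-by-side sketch of the two variants under discussion (helper names are hypothetical; the bodies mirror the diff and the comment above). The cached variant memoizes `resolveAndBind()` per input index; the suggested inline variant re-resolves each time a converter is built:

```scala
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Cached variant, as in the diff: resolveAndBind() runs at most once per
// input index, with the result memoized in a mutable map.
val resolvedEnc = mutable.HashMap[Int, ExpressionEncoder[_]]()
def cachedConverter(i: Int, encoder: ExpressionEncoder[_]): Any => Any = {
  val enc = resolvedEnc.getOrElseUpdate(i, encoder.resolveAndBind())
  (row: Any) => enc.fromRow(row.asInstanceOf[InternalRow])
}

// Inline variant, as suggested in the comment: resolveAndBind() runs each
// time a converter is created, with no shared mutable state.
def inlineConverter(encoder: ExpressionEncoder[_]): Any => Any = {
  val enc = encoder.resolveAndBind()
  (row: Any) => enc.fromRow(row.asInstanceOf[InternalRow])
}
```

The trade-off is only visible if converters can be built more than once for the same input index; otherwise the cache buys nothing, which appears to be the point of the question.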
