asfgit closed pull request #23275: [SPARK-26323][SQL] Scala UDF should still check input types even if some inputs are of type Any
URL: https://github.com/apache/spark/pull/23275
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index b19aa50ba2156..13cc9b9c125e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -882,7 +882,18 @@ object TypeCoercion {
      case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
        val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
-         implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
+         // Currently Scala UDF will only expect `AnyDataType` at top level, so this trick works.
+         // In the future we should create types like `AbstractArrayType`, so that Scala UDF can
+         // accept inputs of array type of arbitrary element type.
+         if (expected == AnyDataType) {
+           in
+         } else {
+           implicitCast(
+             in,
+             udfInputToCastType(in.dataType, expected.asInstanceOf[DataType])
+           ).getOrElse(in)
+         }
+
        }
        udf.withNewChildren(children)
    }
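For context, a minimal sketch of the behavior this hunk fixes (the `spark` session and sample columns are hypothetical, not part of the patch). Before this change, a UDF that mixed a typed parameter with an `Any` parameter skipped input type checking for all arguments; with the `AnyDataType` trick, only the `Any` slot is exempt from the implicit cast:

    // Hypothetical repro in the spirit of SPARK-26323: `x` keeps its Int type
    // check, while `y` (Scala type Any, mapped to AnyDataType) is passed uncast.
    import org.apache.spark.sql.functions.udf

    val f = udf((x: Int, y: Any) => x)
    val df = spark.range(1).selectExpr("CAST(id AS SMALLINT) AS s", "id AS j")
    df.select(f(df("s"), df("j"))).show() // `s` is still implicitly cast to IntegerType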
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index a23aaa3a0b3ef..fae1119c394b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{AbstractDataType, DataType}
/**
* User-defined function.
@@ -48,7 +48,7 @@ case class ScalaUDF(
dataType: DataType,
children: Seq[Expression],
inputsNullSafe: Seq[Boolean],
- inputTypes: Seq[DataType] = Nil,
+ inputTypes: Seq[AbstractDataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
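The signature widening above is what lets a partially-known input list survive: a slot whose Scala type could not be resolved is now represented by `AnyDataType` instead of forcing the whole `inputTypes` list to `Nil`. A hedged sketch of the resulting expression shape (this constructs the internal `ScalaUDF` expression directly, purely for illustration):

    import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
    import org.apache.spark.sql.types.{AnyDataType, IntegerType}

    // One concrete type plus one AnyDataType slot: only the first input
    // participates in implicit casting during analysis.
    val expr = ScalaUDF(
      (x: Int, y: Any) => x,
      IntegerType,
      children = Seq(Literal(1), Literal("a")),
      inputsNullSafe = Seq(false, true),
      inputTypes = Seq(IntegerType, AnyDataType))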
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 5367ce2af8e9f..d2ef08873187e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -96,7 +96,7 @@ private[sql] object TypeCollection {
/**
* An `AbstractDataType` that matches any concrete data types.
*/
-protected[sql] object AnyDataType extends AbstractDataType {
+protected[sql] object AnyDataType extends AbstractDataType with Serializable {
  // Note that since AnyDataType matches any concrete types, defaultConcreteType should never
  // be invoked.
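The `with Serializable` addition matters because `AnyDataType` can now appear inside a `ScalaUDF`'s `inputTypes`, and the whole expression is serialized when tasks are shipped to executors. A plain Scala `object` is not `java.io.Serializable` by default, as this standalone illustration (not Spark code) shows:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    object PlainMarker
    object SerializableMarker extends Serializable

    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    out.writeObject(SerializableMarker) // ok: the singleton round-trips
    // out.writeObject(PlainMarker)     // would throw java.io.NotSerializableException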
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 5a3f556c9c074..fe5d1afd8478a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -123,17 +123,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
       |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
       |  val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
       |  val inputSchemas: Seq[Option[ScalaReflection.Schema]] = $inputSchemas
+      |  val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+      |  val finalUdf = if (nullable) udf else udf.asNonNullable()
       |  def builder(e: Seq[Expression]) = if (e.length == $x) {
-      |    ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-      |      if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-      |      Some(name), nullable, udfDeterministic = true)
+      |    finalUdf.createScalaUDF(e)
       |  } else {
       |    throw new AnalysisException("Invalid number of arguments for function " + name +
       |      ". Expected: $x; Found: " + e.length)
       |  }
       |  functionRegistry.createOrReplaceTempFunction(name, builder)
-      |  val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-      |  if (nullable) udf else udf.asNonNullable()
+      |  finalUdf
       |}""".stripMargin)
  }
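Expanded for one arity, the regenerated pattern builds the `UserDefinedFunction` once and reuses it for both the SQL builder and the returned handle, so the two call paths can no longer disagree about input types or nullability. A hedged usage sketch (assumes a `SparkSession` named `spark`):

    import org.apache.spark.sql.functions.col

    // The same finalUdf backs both paths after this patch.
    val plusOne = spark.udf.register("plusOne", (x: Int) => x + 1)
    spark.sql("SELECT plusOne(CAST(1 AS SMALLINT))").show()      // SQL path
    spark.range(3).select(plusOne(col("id").cast("int"))).show() // DataFrame path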
@@ -170,17 +169,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 0) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 0; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -191,17 +189,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 1) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 1; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -212,17 +209,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 2) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 2; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -233,17 +229,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 3) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 3; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -254,17 +249,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 4) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 4; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -275,17 +269,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 5) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 5; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -296,17 +289,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 6) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 6; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -317,17 +309,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 7) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 7; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -338,17 +329,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 8) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 8; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -359,17 +349,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 9) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 9; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -380,17 +369,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 10) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 10; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -401,17 +389,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 11) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 11; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -422,17 +409,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 12) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 12; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -443,17 +429,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 13) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 13; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -464,17 +449,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 14) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 14; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -485,17 +469,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 15) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 15; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -506,17 +489,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption ::
      Try(ScalaReflection.schemaFor[A16]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 16) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 16; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -527,17 +509,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption ::
      Try(ScalaReflection.schemaFor[A16]).toOption ::
      Try(ScalaReflection.schemaFor[A17]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 17) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 17; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -548,17 +529,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption ::
      Try(ScalaReflection.schemaFor[A16]).toOption ::
      Try(ScalaReflection.schemaFor[A17]).toOption ::
      Try(ScalaReflection.schemaFor[A18]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 18) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 18; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -569,17 +549,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption ::
      Try(ScalaReflection.schemaFor[A16]).toOption ::
      Try(ScalaReflection.schemaFor[A17]).toOption ::
      Try(ScalaReflection.schemaFor[A18]).toOption ::
      Try(ScalaReflection.schemaFor[A19]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 19) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 19; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -590,17 +569,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption ::
      Try(ScalaReflection.schemaFor[A16]).toOption ::
      Try(ScalaReflection.schemaFor[A17]).toOption ::
      Try(ScalaReflection.schemaFor[A18]).toOption ::
      Try(ScalaReflection.schemaFor[A19]).toOption ::
      Try(ScalaReflection.schemaFor[A20]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 20) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 20; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -611,17 +589,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption ::
      Try(ScalaReflection.schemaFor[A16]).toOption ::
      Try(ScalaReflection.schemaFor[A17]).toOption ::
      Try(ScalaReflection.schemaFor[A18]).toOption ::
      Try(ScalaReflection.schemaFor[A19]).toOption ::
      Try(ScalaReflection.schemaFor[A20]).toOption ::
      Try(ScalaReflection.schemaFor[A21]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 21) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 21; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
/**
@@ -632,17 +609,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas: Seq[Option[ScalaReflection.Schema]] =
      Try(ScalaReflection.schemaFor[A1]).toOption ::
      Try(ScalaReflection.schemaFor[A2]).toOption ::
      Try(ScalaReflection.schemaFor[A3]).toOption ::
      Try(ScalaReflection.schemaFor[A4]).toOption ::
      Try(ScalaReflection.schemaFor[A5]).toOption ::
      Try(ScalaReflection.schemaFor[A6]).toOption ::
      Try(ScalaReflection.schemaFor[A7]).toOption ::
      Try(ScalaReflection.schemaFor[A8]).toOption ::
      Try(ScalaReflection.schemaFor[A9]).toOption ::
      Try(ScalaReflection.schemaFor[A10]).toOption ::
      Try(ScalaReflection.schemaFor[A11]).toOption ::
      Try(ScalaReflection.schemaFor[A12]).toOption ::
      Try(ScalaReflection.schemaFor[A13]).toOption ::
      Try(ScalaReflection.schemaFor[A14]).toOption ::
      Try(ScalaReflection.schemaFor[A15]).toOption ::
      Try(ScalaReflection.schemaFor[A16]).toOption ::
      Try(ScalaReflection.schemaFor[A17]).toOption ::
      Try(ScalaReflection.schemaFor[A18]).toOption ::
      Try(ScalaReflection.schemaFor[A19]).toOption ::
      Try(ScalaReflection.schemaFor[A20]).toOption ::
      Try(ScalaReflection.schemaFor[A21]).toOption ::
      Try(ScalaReflection.schemaFor[A22]).toOption :: Nil
+   val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+   val finalUdf = if (nullable) udf else udf.asNonNullable()
    def builder(e: Seq[Expression]) = if (e.length == 22) {
-     ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-       if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-       Some(name), nullable, udfDeterministic = true)
+     finalUdf.createScalaUDF(e)
    } else {
      throw new AnalysisException("Invalid number of arguments for function " + name +
        ". Expected: 22; Found: " + e.length)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder)
-   val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-   if (nullable) udf else udf.asNonNullable()
+   finalUdf
  }
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 901472d8e0360..1b2d6c7ffb529 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.expressions
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.types.{AnyDataType, DataType}
/**
 * A user-defined function. To create one, use the `udf` functions in `functions`.
@@ -88,40 +88,47 @@ sealed abstract class UserDefinedFunction {
 private[sql] case class SparkUserDefinedFunction(
    f: AnyRef,
    dataType: DataType,
-   inputTypes: Option[Seq[DataType]],
-   nullableTypes: Option[Seq[Boolean]],
+   inputSchemas: Seq[Option[ScalaReflection.Schema]],
    name: Option[String] = None,
    nullable: Boolean = true,
    deterministic: Boolean = true) extends UserDefinedFunction {
  @scala.annotation.varargs
  override def apply(exprs: Column*): Column = {
-   // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
-   // and `nullableTypes` is always set.
-   if (inputTypes.isDefined) {
-     assert(inputTypes.get.length == nullableTypes.get.length)
-   }
+   Column(createScalaUDF(exprs.map(_.expr)))
+ }
+
+ private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = {
+   // It's possible that some of the inputs don't have a specific type (e.g. `Any`), skip type
+   // check and null check for them.
+   val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType))
-   val inputsNullSafe = nullableTypes.getOrElse {
+   val inputsNullSafe = if (inputSchemas.isEmpty) {
+     // This is for backward compatibility of `functions.udf(AnyRef, DataType)`. We need to
+     // do reflection of the lambda function object and see if its arguments are nullable or not.
+     // This doesn't work for Scala 2.12 and we should consider removing this workaround, as Spark
+     // uses Scala 2.12 by default since 3.0.
      ScalaReflection.getParameterTypeNullability(f)
+   } else {
+     inputSchemas.map(_.map(_.nullable).getOrElse(true))
    }
-   Column(ScalaUDF(
+   ScalaUDF(
      f,
      dataType,
-     exprs.map(_.expr),
+     exprs,
      inputsNullSafe,
-     inputTypes.getOrElse(Nil),
+     inputTypes,
      udfName = name,
      nullable = nullable,
-     udfDeterministic = deterministic))
+     udfDeterministic = deterministic)
  }
- override def withName(name: String): UserDefinedFunction = {
+ override def withName(name: String): SparkUserDefinedFunction = {
    copy(name = Option(name))
  }
- override def asNonNullable(): UserDefinedFunction = {
+ override def asNonNullable(): SparkUserDefinedFunction = {
    if (!nullable) {
      this
    } else {
@@ -129,7 +136,7 @@ private[sql] case class SparkUserDefinedFunction(
    }
  }
- override def asNondeterministic(): UserDefinedFunction = {
+ override def asNondeterministic(): SparkUserDefinedFunction = {
    if (!deterministic) {
      this
    } else {
@@ -137,19 +144,3 @@ private[sql] case class SparkUserDefinedFunction(
    }
  }
}
-
-private[sql] object SparkUserDefinedFunction {
-
- def create(
-     f: AnyRef,
-     dataType: DataType,
-     inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = {
-   val inputTypes = if (inputSchemas.contains(None)) {
-     None
-   } else {
-     Some(inputSchemas.map(_.get.dataType))
-   }
-   val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
-   SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes)
- }
-}
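With `createScalaUDF` as the single conversion point, the untyped overloads mentioned in the comment above behave predictably: `functions.udf(f, returnType)` passes an empty `inputSchemas`, so no input types are declared (hence no implicit casts) and argument nullability falls back to reflection over the lambda. A hedged sketch (assumes a `SparkSession` named `spark`):

    import org.apache.spark.sql.functions.{col, udf}
    import org.apache.spark.sql.types.IntegerType

    // Untyped variant: inputSchemas is empty, so the argument keeps its
    // LongType and no cast is inserted before the UDF runs.
    val strLen = udf((s: Any) => s.toString.length, IntegerType)
    spark.range(2).select(strLen(col("id"))).show()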
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 645452553e6a5..7572cf23cde8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3874,7 +3874,7 @@ object functions {
|def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
| val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
| val inputSchemas = $inputSchemas
- | val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ | val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
| if (nullable) udf else udf.asNonNullable()
|}""".stripMargin)
}
@@ -3897,7 +3897,7 @@ object functions {
| */
|def udf(f: UDF$i[$extTypeArgs], returnType: DataType):
UserDefinedFunction = {
| val func = f$anyCast.call($anyParams)
- | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas =
Seq.fill($i)(None))
+ | SparkUserDefinedFunction($funcCall, returnType, inputSchemas =
Seq.fill($i)(None))
|}""".stripMargin)
}
@@ -3919,7 +3919,7 @@ object functions {
def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -3935,7 +3935,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction
= {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -3951,7 +3951,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]):
UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -3967,7 +3967,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1,
A2, A3, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -3983,7 +3983,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f:
Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -3999,7 +3999,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5:
TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A4])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -4015,7 +4015,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5:
TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]):
UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A4])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A5])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -4031,7 +4031,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5:
TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7,
RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A4])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A5])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A6])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -4047,7 +4047,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5:
TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4,
A5, A6, A7, A8, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A4])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A5])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A6])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A7])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -4063,7 +4063,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5:
TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1,
A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A4])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A5])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A6])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A7])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A8])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
@@ -4079,7 +4079,7 @@ object functions {
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5:
TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f:
Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction =
{
val ScalaReflection.Schema(dataType, nullable) =
ScalaReflection.schemaFor[RT]
val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A2])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A3])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A4])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A5])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A6])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A7])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A8])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A9])).toOption ::
Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil
- val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+ val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
if (nullable) udf else udf.asNonNullable()
}
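(The hunks above are mechanically identical; only the arity changes. The user-visible effect of routing through the SparkUserDefinedFunction apply method is that one Any parameter no longer disables type checking for the remaining inputs. A hedged usage sketch, assuming a SparkSession `spark` with spark.implicits._ imported; it mirrors the regression test added to UDFSuite below:)

import org.apache.spark.sql.functions.udf

val f = udf((x: Long, y: Any) => x)
// "i" is IntegerType and is still implicitly cast to LongType for x: Long,
// while "j" flows through the Any slot without a cast.
val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j").select(f($"i", $"j"))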
@@ -4098,7 +4098,7 @@ object functions {
*/
def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
val func = f.asInstanceOf[UDF0[Any]].call()
- SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None))
+ SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
}
/**
@@ -4112,7 +4112,7 @@ object functions {
*/
def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = {
val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(1)(None))
}
/**
@@ -4126,7 +4126,7 @@ object functions {
*/
def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = {
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(2)(None))
}
/**
@@ -4140,7 +4140,7 @@ object functions {
*/
def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = {
val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(3)(None))
}
/**
@@ -4154,7 +4154,7 @@ object functions {
*/
def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction =
{
val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(4)(None))
}
/**
@@ -4168,7 +4168,7 @@ object functions {
*/
def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType):
UserDefinedFunction = {
val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(5)(None))
}
/**
@@ -4182,7 +4182,7 @@ object functions {
*/
def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType):
UserDefinedFunction = {
val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(6)(None))
}
/**
@@ -4196,7 +4196,7 @@ object functions {
*/
def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType):
UserDefinedFunction = {
val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(7)(None))
}
/**
@@ -4210,7 +4210,7 @@ object functions {
*/
def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType):
UserDefinedFunction = {
val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(8)(None))
}
/**
@@ -4224,7 +4224,7 @@ object functions {
*/
def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType):
UserDefinedFunction = {
val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(9)(None))
}
/**
@@ -4238,7 +4238,7 @@ object functions {
*/
def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType):
UserDefinedFunction = {
val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None))
+ SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(10)(None))
}
// scalastyle:on parameter.number
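(All of the Java-facing overloads above pass inputSchemas = Seq.fill(n)(None): n inputs with no type information, since the java.api UDFn interfaces carry no TypeTags, so none of their inputs participate in coercion. A hedged sketch of calling one of these overloads from Scala:)

import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.LongType

// Only the return type is declared; both inputs are unchecked (None):
val javaAdd = udf(new UDF2[Long, Long, Long] {
  override def call(x: Long, y: Long): Long = x + y
}, LongType)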
@@ -4257,9 +4257,7 @@ object functions {
* @since 2.0.0
*/
def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
- // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently
- // unavailable. We may need to create type-safe overloaded versions of udf() methods.
- SparkUserDefinedFunction(f, dataType, inputTypes = None, nullableTypes = None)
+ SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
}
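(With the TODO resolved, the untyped udf(f: AnyRef, dataType) variant now goes through the same apply method, declaring inputSchemas = Nil: no input type information at all, only the declared return type. A hedged usage sketch:)

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

// No input schemas are declared, so "s" gets no coercion or checking:
val strlen = udf((s: String) => s.length, IntegerType)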
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index a26d306cff6b5..06b9343c37581 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -450,4 +450,19 @@ class UDFSuite extends QueryTest with SharedSQLContext {
})
checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556"))))
}
+
+ test("SPARK-26323 Verify input type check - with udf()") {
+ val f = udf((x: Long, y: Any) => x)
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j").select(f($"i", $"j"))
+ checkAnswer(df, Seq(Row(1L), Row(2L)))
+ }
+
+ test("SPARK-26323 Verify input type check - with udf.register") {
+ withTable("t") {
+ Seq(1 -> "a", 2 -> "b").toDF("i",
"j").write.format("json").saveAsTable("t")
+ spark.udf.register("f", (x: Long, y: Any) => x)
+ val df = spark.sql("SELECT f(i, j) FROM t")
+ checkAnswer(df, Seq(Row(1L), Row(2L)))
+ }
+ }
}
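(On the two regression tests: the first exercises the DataFrame path, where column i is IntegerType and must still be cast to LongType for x: Long even though y is Any; the second exercises the SQL path via spark.udf.register, and since JSON reads i back as bigint, f(i, j) binds x directly. A hedged way to observe the injected cast on the DataFrame path, assuming spark.implicits._ is in scope:)

import org.apache.spark.sql.functions.udf

val f = udf((x: Long, y: Any) => x)
Seq(1 -> "a").toDF("i", "j").select(f($"i", $"j")).explain(true)
// The analyzed plan should show the first argument wrapped in a cast to
// bigint, while the second argument appears uncast.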
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]