Repository: spark Updated Branches: refs/heads/master 5aeb7384c -> f00df40cf
[SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF Currently PySpark can only call built-in Java UDFs; it cannot call custom Java UDFs. It would be better to allow that, for two reasons: * It leverages the power of the rich ecosystem of third-party Java libraries. * It improves performance: using a Python UDF requires Python daemons to be started on the workers, which hurts performance. Author: Jeff Zhang <zjf...@apache.org> Closes #9766 from zjffdu/SPARK-11775. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f00df40c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f00df40c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f00df40c Branch: refs/heads/master Commit: f00df40cfefef0f3fc73f16ada1006e4dcfa5a39 Parents: 5aeb738 Author: Jeff Zhang <zjf...@apache.org> Authored: Fri Oct 14 15:50:35 2016 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Fri Oct 14 15:50:35 2016 -0700 ---------------------------------------------------------------------- python/pyspark/sql/context.py | 28 +++++++- .../spark/sql/catalyst/JavaTypeInference.scala | 2 +- .../org/apache/spark/sql/UDFRegistration.scala | 75 +++++++++++++++++++- .../org/apache/spark/sql/JavaStringLength.java | 30 ++++++++ .../test/org/apache/spark/sql/JavaUDFSuite.java | 21 ++++++ 5 files changed, 152 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/python/pyspark/sql/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 8264dcf..de4c335 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -28,7 +28,7 @@ from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from 
pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, StringType +from pyspark.sql.types import IntegerType, Row, StringType from pyspark.sql.utils import install_exception_handler __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] @@ -202,6 +202,32 @@ class SQLContext(object): """ self.sparkSession.catalog.registerFunction(name, f, returnType) + @ignore_unicode_prefix + @since(2.1) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a java UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + :param name: name of the UDF + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> sqlContext.registerJavaFunction("javaStringLength", + ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> sqlContext.sql("SELECT javaStringLength('test')").collect() + [Row(UDF(test)=4)] + >>> sqlContext.registerJavaFunction("javaStringLength2", + ... 
"test.org.apache.spark.sql.JavaStringLength") + >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF(test)=4)] + + """ + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index e6f61b0..04f0cfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -59,7 +59,7 @@ object JavaTypeInference { * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 617a147..0444ad1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,19 +17,25 @@ package org.apache.spark.sql +import java.io.IOException +import java.lang.reflect.{ParameterizedType, Type} + import scala.reflect.runtime.universe.TypeTag import scala.util.Try +import com.google.common.reflect.TypeToken + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. @@ -414,6 +420,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ////////////////////////////////////////////////////////////////////////////////////////////// /** + * Register a Java UDF class using reflection, for use from pyspark + * + * @param name udf name + * @param className fully qualified class name of udf + * @param returnDataType return type of udf. If it is null, spark would try to infer + * via reflection. 
+ */ + private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + + try { + val clazz = Utils.classForName(className) + val udfInterfaces = clazz.getGenericInterfaces + .filter(_.isInstanceOf[ParameterizedType]) + .map(_.asInstanceOf[ParameterizedType]) + .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) + if (udfInterfaces.length == 0) { + throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + } else if (udfInterfaces.length > 1) { + throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + } else { + try { + val udf = clazz.newInstance() + val udfReturnType = udfInterfaces(0).getActualTypeArguments.last + var returnType = returnDataType + if (returnType == null) { + returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1 + } + + udfInterfaces(0).getActualTypeArguments.length match { + case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) + case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) + case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType) + case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType) + case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType) + case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType) + case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType) + case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType) + case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType) + case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType) + case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 13 => register(name, 
udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case n => logError(s"UDF class with ${n} type arguments is not supported ") + } + } catch { + case e @ (_: InstantiationException | _: IllegalArgumentException) => + logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + } + } + } catch { + case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + } + + } + + /** * Register a user-defined function with 1 arguments. 
* @since 1.3.0 */ http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java new file mode 100644 index 0000000..b90224f --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.api.java.UDF1; + +/** + * It is used for register Java UDF from PySpark + */ +public class JavaStringLength implements UDF1<String, Integer> { + @Override + public Integer call(String str) throws Exception { + return new Integer(str.length()); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 2274912..8bf3278 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -87,4 +87,25 @@ public class JavaUDFSuite implements Serializable { Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } + + public static class StringLengthTest implements UDF2<String, String, Integer> { + @Override + public Integer call(String str1, String str2) throws Exception { + return new Integer(str1.length() + str2.length()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void udf3Test() { + spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), + DataTypes.IntegerType); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); + + // returnType is not provided + spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null); + result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org